infill_loss.py
Go to the documentation of this file.
"""
HolePixelLoss

This loss mimics NVIDIA's pixelwise L1 loss for holes, as used in the
infill network.

Loss classes designed for a network that infills tracks in dead channels.
There are separate classes for collection views (unipolar signals) and
induction views (bipolar signals).

Taken from github.com/NuTufts/sparse_infill.
"""
10 
11 import torch
12 import torch.nn as nn
13 
14 
# Adapted from torch.nn.modules.loss.
def _assert_no_grad(variable):
    """Check that a loss target does not require gradients.

    Args:
        variable: tensor used as the target of a criterion.

    Raises:
        AssertionError: if ``variable.requires_grad`` is set. The check is an
            ``assert`` statement, so it is skipped when Python runs with ``-O``.
    """
    message = ("nn criterions don't compute the gradient w.r.t. targets - please "
               "mark these variables as not requiring gradients")
    assert not variable.requires_grad, message
20 
21 
class InfillLossInduction(nn.modules.loss._WeightedLoss):
    """Weighted L1 infill loss for induction views (bipolar signals).

    ``forward`` returns five partial losses and their sum:

    * live ("non-dead") channels, weight 1;
    * dead channels with |true| < 10, weight 500;
    * dead channels with 10 <= |true| <= 40, weight 100, plus the pixels
      where the bipolar true signal crosses zero (see ``_zero_crossing_mask``);
    * dead channels with 40 < |true| <= 70, weight 100;
    * dead channels with |true| > 70, weight 100.

    A channel is one column of the image, indexed by the last tensor
    dimension. A channel is dead when every ``input`` value in that column is
    zero. Each term is the L1 distance restricted to its pixels, multiplied by
    its weight and divided by its pixel count (1 when the bin is empty).

    NOTE(review): ``weight``, ``ignore_index`` and ``self.reduce`` are stored
    but never read by ``forward``.
    """

    def __init__(self, weight=None, size_average=False, ignore_index=-100):
        super(InfillLossInduction, self).__init__(weight, size_average)
        self.ignore_index = ignore_index
        self.reduce = False
        self.size_average = size_average

    @staticmethod
    def _live_channels(input):
        """Return a bool mask shaped like ``input``.

        The mask is True on every pixel of a column (last-dimension index)
        that contains at least one non-zero ``input`` value. This is applied
        to every batch entry and channel, not only the first one.
        """
        return (input != 0).any(dim=-2, keepdim=True).expand_as(input)

    @staticmethod
    def _charge_bins(true, deadch):
        """Split the dead-channel mask into four disjoint |true| bins.

        Returns float masks ``(none, low, high, highest)`` for
        |true| < 10, 10 <= |true| <= 40, 40 < |true| <= 70 and |true| > 70.
        Together they cover every dead pixel, so values of exactly 10, 40 or
        70 are not silently dropped from the loss.
        """
        magnitude = true.abs()
        nocharge = deadch * (magnitude < 10).float()
        lowcharge = deadch * ((magnitude >= 10) & (magnitude <= 40)).float()
        highcharge = deadch * ((magnitude > 40) & (magnitude <= 70)).float()
        highestcharge = deadch * (magnitude > 70).float()
        return nocharge, lowcharge, highcharge, highestcharge

    @staticmethod
    def _zero_crossing_mask(deadtrue, nocharge):
        """Return the ``nocharge`` pixels whose neighbours along dim -2 differ by > 20.

        ``deadtrue`` is the true image restricted to dead channels. For every
        interior row h, the central difference ``deadtrue[h-1] - deadtrue[h+1]``
        is taken along dimension -2; the first and last rows get 0. A bipolar
        signal passes through |true| < 10 where it crosses zero, and those
        pixels belong to the low-charge bin rather than the no-charge bin.
        """
        diffs = torch.zeros_like(deadtrue)
        diffs[..., 1:-1, :] = deadtrue[..., :-2, :] - deadtrue[..., 2:, :]
        return (diffs.abs() > 20).float() * nocharge

    def _masked_l1(self, predict, true, mask, weight):
        """Weighted L1 distance over the pixels where ``mask`` is 1.

        The result is divided by the number of masked pixels, clamped to at
        least 1 so an empty bin yields 0 instead of NaN. The reduction mirrors
        torch's legacy handling of ``size_average`` (None or truthy -> 'mean',
        otherwise 'sum') without using the deprecated argument.
        """
        reduction = 'mean' if self.size_average is None or self.size_average else 'sum'
        l1loss = nn.L1Loss(reduction=reduction)
        count = mask.sum().float().clamp(min=1.0)
        return (l1loss(predict * mask, true * mask) * weight) / count

    def forward(self, predict, true, input):
        """Compute the five partial losses and their sum.

        Args:
            predict: network output, same shape as ``true``.
            true: target image; must not require gradients.
            input: network input; columns that are all zero mark dead
                channels. NOTE(review): presumably 4-D (batch, channel,
                height, width); confirm against the caller.

        Returns:
            Tuple ``(nondeadloss, deadnochargeloss, deadlowchargeloss,
            deadhighchargeloss, deadhighestchargeloss, totloss)``.
        """
        _assert_no_grad(true)

        nondeadweight = 1.0
        deadnochargeweight = 500.0
        deadlowchargeweight = 100.0
        deadhighchargeweight = 100.0
        deadhighestchargeweight = 100.0

        live = self._live_channels(input)
        goodch = live.float()
        deadch = (~live).float()

        (deadchnocharge, deadchlowcharge,
         deadchhighcharge, deadchhighestcharge) = self._charge_bins(true, deadch)

        # Move zero-crossing pixels of real bipolar tracks from "no charge" to "low charge".
        hidden = self._zero_crossing_mask(true * deadch, deadchnocharge)
        deadchnocharge = deadchnocharge - hidden
        deadchlowcharge = deadchlowcharge + hidden

        nondeadloss = self._masked_l1(predict, true, goodch, nondeadweight)
        deadnochargeloss = self._masked_l1(predict, true, deadchnocharge, deadnochargeweight)
        deadlowchargeloss = self._masked_l1(predict, true, deadchlowcharge, deadlowchargeweight)
        deadhighchargeloss = self._masked_l1(predict, true, deadchhighcharge, deadhighchargeweight)
        deadhighestchargeloss = self._masked_l1(predict, true, deadchhighestcharge,
                                                deadhighestchargeweight)

        totloss = (nondeadloss + deadnochargeloss + deadlowchargeloss
                   + deadhighchargeloss + deadhighestchargeloss)
        return (nondeadloss, deadnochargeloss, deadlowchargeloss,
                deadhighchargeloss, deadhighestchargeloss, totloss)
107  totloss = nondeadloss + deadnochargeloss + deadlowchargeloss +deadhighchargeloss+deadhighestchargeloss
108  return nondeadloss, deadnochargeloss, deadlowchargeloss, deadhighchargeloss,deadhighestchargeloss, totloss
109 
110 
class InfillLossCollection(nn.modules.loss._WeightedLoss):
    """Weighted L1 infill loss for collection views (unipolar signals).

    ``forward`` returns five partial losses and their sum:

    * live ("non-dead") channels, weight 1;
    * dead channels with |true| < 10, weight 500;
    * dead channels with 10 <= |true| <= 40, weight 100;
    * dead channels with 40 < |true| <= 70, weight 100;
    * dead channels with |true| > 70, weight 100.

    A channel is one column of the image, indexed by the last tensor
    dimension. A channel is dead when every ``input`` value in that column is
    zero. Each term is the L1 distance restricted to its pixels, multiplied by
    its weight and divided by its pixel count (1 when the bin is empty).

    NOTE(review): ``weight``, ``ignore_index`` and ``self.reduce`` are stored
    but never read by ``forward``.
    """

    def __init__(self, weight=None, size_average=False, ignore_index=-100):
        super(InfillLossCollection, self).__init__(weight, size_average)
        self.ignore_index = ignore_index
        self.reduce = False
        self.size_average = size_average

    @staticmethod
    def _live_channels(input):
        """Return a bool mask shaped like ``input``.

        The mask is True on every pixel of a column (last-dimension index)
        that contains at least one non-zero ``input`` value. This is applied
        to every batch entry and channel, not only the first one.
        """
        return (input != 0).any(dim=-2, keepdim=True).expand_as(input)

    @staticmethod
    def _charge_bins(true, deadch):
        """Split the dead-channel mask into four disjoint |true| bins.

        Returns float masks ``(none, low, high, highest)`` for
        |true| < 10, 10 <= |true| <= 40, 40 < |true| <= 70 and |true| > 70.
        Together they cover every dead pixel, so values of exactly 10, 40 or
        70 are not silently dropped from the loss.
        """
        magnitude = true.abs()
        nocharge = deadch * (magnitude < 10).float()
        lowcharge = deadch * ((magnitude >= 10) & (magnitude <= 40)).float()
        highcharge = deadch * ((magnitude > 40) & (magnitude <= 70)).float()
        highestcharge = deadch * (magnitude > 70).float()
        return nocharge, lowcharge, highcharge, highestcharge

    def _masked_l1(self, predict, true, mask, weight):
        """Weighted L1 distance over the pixels where ``mask`` is 1.

        The result is divided by the number of masked pixels, clamped to at
        least 1 so an empty bin yields 0 instead of NaN. The reduction mirrors
        torch's legacy handling of ``size_average`` (None or truthy -> 'mean',
        otherwise 'sum') without using the deprecated argument.
        """
        reduction = 'mean' if self.size_average is None or self.size_average else 'sum'
        l1loss = nn.L1Loss(reduction=reduction)
        count = mask.sum().float().clamp(min=1.0)
        return (l1loss(predict * mask, true * mask) * weight) / count

    def forward(self, predict, true, input):
        """Compute the five partial losses and their sum.

        Args:
            predict: network output, same shape as ``true``.
            true: target image; must not require gradients.
            input: network input; columns that are all zero mark dead
                channels. NOTE(review): presumably 4-D (batch, channel,
                height, width); confirm against the caller.

        Returns:
            Tuple ``(nondeadloss, deadnochargeloss, deadlowchargeloss,
            deadhighchargeloss, deadhighestchargeloss, totloss)``.
        """
        _assert_no_grad(true)

        nondeadweight = 1.0
        deadnochargeweight = 500.0
        deadlowchargeweight = 100.0
        deadhighchargeweight = 100.0
        deadhighestchargeweight = 100.0

        live = self._live_channels(input)
        goodch = live.float()
        deadch = (~live).float()

        (deadchnocharge, deadchlowcharge,
         deadchhighcharge, deadchhighestcharge) = self._charge_bins(true, deadch)

        nondeadloss = self._masked_l1(predict, true, goodch, nondeadweight)
        deadnochargeloss = self._masked_l1(predict, true, deadchnocharge, deadnochargeweight)
        deadlowchargeloss = self._masked_l1(predict, true, deadchlowcharge, deadlowchargeweight)
        deadhighchargeloss = self._masked_l1(predict, true, deadchhighcharge, deadhighchargeweight)
        deadhighestchargeloss = self._masked_l1(predict, true, deadchhighestcharge,
                                                deadhighestchargeweight)

        totloss = (nondeadloss + deadnochargeloss + deadlowchargeloss
                   + deadhighchargeloss + deadhighestchargeloss)
        return (nondeadloss, deadnochargeloss, deadlowchargeloss,
                deadhighchargeloss, deadhighestchargeloss, totloss)
def _assert_no_grad(variable)
Definition: infill_loss.py:16
def __init__(self, weight=None, size_average=False, ignore_index=-100)
Definition: infill_loss.py:112
auto enumerate(Iterables &&...iterables)
Range-for loop helper tracking the number of iteration.
Definition: enumerate.h:69
def forward(self, predict, true, input)
Definition: infill_loss.py:118
def forward(self, predict, true, input)
Definition: infill_loss.py:29
def __init__(self, weight=None, size_average=False, ignore_index=-100)
Definition: infill_loss.py:23