3 This loss mimics nividia's pixelwise loss for holes (L1) 4 used in the infill network 6 Loss function designed for a network that infills tracks. Separate functions for collection views 7 (unipolar signals) and induction views (bipolar signal) 8 Taken from github.com/NuTufts/sparse_infill. 17 assert not variable.requires_grad, \
18 "nn criterions don't compute the gradient w.r.t. targets - please " \
19 "mark these variables as not requiring gradients" 23 def __init__(self, weight=None, size_average=False, ignore_index=-100):
24 super(InfillLossInduction, self).
__init__(weight, size_average)
35 deadnochargeweight = 500.0
36 deadlowchargeweight = 100.0
37 deadhighchargeweight = 100.0
38 deadhighestchargeweight = 100.0
42 for idx, col
in enumerate(goodch[0, 0, :, :].T):
44 goodch[0, 0, :, idx] = 1
45 deadch = (~goodch).
float()
46 goodch = goodch.float()
49 predictgood = goodch * predict
50 adcgood = goodch * true
51 totnondead = goodch.sum().
float()
54 nondeadloss = (L1loss(predictgood, adcgood)*nondeadweight)/totnondead
57 deadchhighestcharge = deadch * (true.abs() > 70).
float()
58 predictdeadhighestcharge = predict*deadchhighestcharge
59 adcdeadhighestcharge = true*deadchhighestcharge
60 totdeadhighestcharge = deadchhighestcharge.sum().
float()
61 if (totdeadhighestcharge == 0):
62 totdeadhighestcharge = 1.0
63 deadhighestchargeloss = (L1loss(predictdeadhighestcharge,adcdeadhighestcharge)*
64 deadhighestchargeweight)/totdeadhighestcharge
67 deadchhighcharge = deadch * (true.abs() > 40).
float() * (true.abs() < 70).
float()
68 predictdeadhighcharge = predict*deadchhighcharge
69 adcdeadhighcharge = true*deadchhighcharge
70 totdeadhighcharge = deadchhighcharge.sum().
float()
71 if (totdeadhighcharge == 0):
72 totdeadhighcharge = 1.0
73 deadhighchargeloss = (L1loss(predictdeadhighcharge,adcdeadhighcharge)*
74 deadhighchargeweight)/totdeadhighcharge
77 deadchlowcharge = deadch * (true.abs() > 10).
float() * (true.abs() < 40).
float()
78 deadchnocharge = deadch * (true.abs() < 10).
float()
82 dead_cols = torch.nonzero(deadch)
84 dead_cols = torch.unique(torch.nonzero(img[0, :, :])[:, 1])
86 diffs = img[0, :, col].roll(1) - img[0, :, col].roll(-1)
89 deadchlowcharge_hidden = (diffs.abs() > 20).
float() * deadchnocharge[i, 0, :, col].
float()
90 deadchnocharge[i, 0, :, col] -= deadchlowcharge_hidden
91 deadchlowcharge[i, 0, :, col] += deadchlowcharge_hidden
93 predictdeadlowcharge = predict*deadchlowcharge
94 adcdeadlowcharge = true*deadchlowcharge
95 totdeadlowcharge = deadchlowcharge.sum().
float()
96 if (totdeadlowcharge == 0):
97 totdeadlowcharge = 1.0
98 deadlowchargeloss = (L1loss(predictdeadlowcharge,adcdeadlowcharge)*deadlowchargeweight)/totdeadlowcharge
100 predictdeadnocharge = predict*deadchnocharge
101 adcdeadnocharge = true*deadchnocharge
102 totdeadnocharge = deadchnocharge.sum().
float()
103 if (totdeadnocharge == 0):
104 totdeadnocharge = 1.0
105 deadnochargeloss = (L1loss(predictdeadnocharge,adcdeadnocharge)*deadnochargeweight)/totdeadnocharge
107 totloss = nondeadloss + deadnochargeloss + deadlowchargeloss +deadhighchargeloss+deadhighestchargeloss
108 return nondeadloss, deadnochargeloss, deadlowchargeloss, deadhighchargeloss,deadhighestchargeloss, totloss
112 def __init__(self, weight=None, size_average=False, ignore_index=-100):
113 super(InfillLossCollection, self).
__init__(weight, size_average)
124 deadnochargeweight = 500.0
125 deadlowchargeweight = 100.0
126 deadhighchargeweight = 100.0
127 deadhighestchargeweight = 100.0
130 goodch = (input != 0)
131 for idx, col
in enumerate(goodch[0, 0, :, :].T):
133 goodch[0, 0, :, idx] = 1
134 deadch = (~goodch).
float()
135 goodch = goodch.float()
138 predictgood = goodch * predict
139 adcgood = goodch * true
140 totnondead = goodch.sum().
float()
141 if (totnondead == 0):
143 nondeadloss = (L1loss(predictgood, adcgood)*nondeadweight)/totnondead
146 deadchhighestcharge = deadch * (true.abs() > 70).
float()
147 predictdeadhighestcharge = predict*deadchhighestcharge
148 adcdeadhighestcharge = true*deadchhighestcharge
149 totdeadhighestcharge = deadchhighestcharge.sum().
float()
150 if (totdeadhighestcharge == 0):
151 totdeadhighestcharge = 1.0
152 deadhighestchargeloss = (L1loss(predictdeadhighestcharge,adcdeadhighestcharge)*deadhighestchargeweight)/totdeadhighestcharge
155 deadchhighcharge = deadch * (true.abs() > 40).
float()*(true.abs() < 70).
float()
156 predictdeadhighcharge = predict*deadchhighcharge
157 adcdeadhighcharge = true*deadchhighcharge
158 totdeadhighcharge = deadchhighcharge.sum().
float()
159 if (totdeadhighcharge == 0):
160 totdeadhighcharge = 1.0
161 deadhighchargeloss = (L1loss(predictdeadhighcharge,adcdeadhighcharge)*deadhighchargeweight)/totdeadhighcharge
164 deadchlowcharge = deadch * (true.abs() > 10).
float() *(true.abs() < 40).
float()
165 deadchnocharge = deadch * (true.abs() < 10).
float()
166 predictdeadlowcharge = predict*deadchlowcharge
167 adcdeadlowcharge = true*deadchlowcharge
168 totdeadlowcharge = deadchlowcharge.sum().
float()
169 if (totdeadlowcharge == 0):
170 totdeadlowcharge = 1.0
171 deadlowchargeloss = (L1loss(predictdeadlowcharge,adcdeadlowcharge)*deadlowchargeweight)/totdeadlowcharge
174 predictdeadnocharge = predict*deadchnocharge
175 adcdeadnocharge = true*deadchnocharge
176 totdeadnocharge = deadchnocharge.sum().
float()
177 if (totdeadnocharge == 0):
178 totdeadnocharge = 1.0
179 deadnochargeloss = (L1loss(predictdeadnocharge,adcdeadnocharge)*deadnochargeweight)/totdeadnocharge
181 totloss = nondeadloss + deadnochargeloss + deadlowchargeloss +deadhighchargeloss+deadhighestchargeloss
182 return nondeadloss, deadnochargeloss, deadlowchargeloss, deadhighchargeloss,deadhighestchargeloss, totloss
def _assert_no_grad(variable)
def __init__(self, weight=None, size_average=False, ignore_index=-100)
auto enumerate(Iterables &&...iterables)
Range-for loop helper tracking the number of iteration.
def forward(self, predict, true, input)
def forward(self, predict, true, input)
def __init__(self, weight=None, size_average=False, ignore_index=-100)