124 deadnochargeweight = 500.0
125 deadlowchargeweight = 100.0
126 deadhighchargeweight = 100.0
127 deadhighestchargeweight = 100.0
130 goodch = (input != 0)
131 for idx, col
in enumerate(goodch[0, 0, :, :].T):
133 goodch[0, 0, :, idx] = 1
134 deadch = (~goodch).
float()
135 goodch = goodch.float()
138 predictgood = goodch * predict
139 adcgood = goodch * true
140 totnondead = goodch.sum().
float()
141 if (totnondead == 0):
143 nondeadloss = (L1loss(predictgood, adcgood)*nondeadweight)/totnondead
146 deadchhighestcharge = deadch * (true.abs() > 70).
float()
147 predictdeadhighestcharge = predict*deadchhighestcharge
148 adcdeadhighestcharge = true*deadchhighestcharge
149 totdeadhighestcharge = deadchhighestcharge.sum().
float()
150 if (totdeadhighestcharge == 0):
151 totdeadhighestcharge = 1.0
152 deadhighestchargeloss = (L1loss(predictdeadhighestcharge,adcdeadhighestcharge)*deadhighestchargeweight)/totdeadhighestcharge
155 deadchhighcharge = deadch * (true.abs() > 40).
float()*(true.abs() < 70).
float()
156 predictdeadhighcharge = predict*deadchhighcharge
157 adcdeadhighcharge = true*deadchhighcharge
158 totdeadhighcharge = deadchhighcharge.sum().
float()
159 if (totdeadhighcharge == 0):
160 totdeadhighcharge = 1.0
161 deadhighchargeloss = (L1loss(predictdeadhighcharge,adcdeadhighcharge)*deadhighchargeweight)/totdeadhighcharge
164 deadchlowcharge = deadch * (true.abs() > 10).
float() *(true.abs() < 40).
float()
165 deadchnocharge = deadch * (true.abs() < 10).
float()
166 predictdeadlowcharge = predict*deadchlowcharge
167 adcdeadlowcharge = true*deadchlowcharge
168 totdeadlowcharge = deadchlowcharge.sum().
float()
169 if (totdeadlowcharge == 0):
170 totdeadlowcharge = 1.0
171 deadlowchargeloss = (L1loss(predictdeadlowcharge,adcdeadlowcharge)*deadlowchargeweight)/totdeadlowcharge
174 predictdeadnocharge = predict*deadchnocharge
175 adcdeadnocharge = true*deadchnocharge
176 totdeadnocharge = deadchnocharge.sum().
float()
177 if (totdeadnocharge == 0):
178 totdeadnocharge = 1.0
179 deadnochargeloss = (L1loss(predictdeadnocharge,adcdeadnocharge)*deadnochargeweight)/totdeadnocharge
181 totloss = nondeadloss + deadnochargeloss + deadlowchargeloss +deadhighchargeloss+deadhighestchargeloss
182 return nondeadloss, deadnochargeloss, deadlowchargeloss, deadhighchargeloss,deadhighestchargeloss, totloss
def _assert_no_grad(variable)
auto enumerate(Iterables &&...iterables)
Range-for loop helper tracking the number of iterations.
def forward(self, predict, true, input)