35 deadnochargeweight = 500.0
36 deadlowchargeweight = 100.0
37 deadhighchargeweight = 100.0
38 deadhighestchargeweight = 100.0
42 for idx, col
in enumerate(goodch[0, 0, :, :].T):
44 goodch[0, 0, :, idx] = 1
45 deadch = (~goodch).
float()
46 goodch = goodch.float()
49 predictgood = goodch * predict
50 adcgood = goodch * true
51 totnondead = goodch.sum().
float()
54 nondeadloss = (L1loss(predictgood, adcgood)*nondeadweight)/totnondead
57 deadchhighestcharge = deadch * (true.abs() > 70).
float()
58 predictdeadhighestcharge = predict*deadchhighestcharge
59 adcdeadhighestcharge = true*deadchhighestcharge
60 totdeadhighestcharge = deadchhighestcharge.sum().
float()
61 if (totdeadhighestcharge == 0):
62 totdeadhighestcharge = 1.0
63 deadhighestchargeloss = (L1loss(predictdeadhighestcharge,adcdeadhighestcharge)*
64 deadhighestchargeweight)/totdeadhighestcharge
67 deadchhighcharge = deadch * (true.abs() > 40).
float() * (true.abs() < 70).
float()
68 predictdeadhighcharge = predict*deadchhighcharge
69 adcdeadhighcharge = true*deadchhighcharge
70 totdeadhighcharge = deadchhighcharge.sum().
float()
71 if (totdeadhighcharge == 0):
72 totdeadhighcharge = 1.0
73 deadhighchargeloss = (L1loss(predictdeadhighcharge,adcdeadhighcharge)*
74 deadhighchargeweight)/totdeadhighcharge
77 deadchlowcharge = deadch * (true.abs() > 10).
float() * (true.abs() < 40).
float()
78 deadchnocharge = deadch * (true.abs() < 10).
float()
82 dead_cols = torch.nonzero(deadch)
84 dead_cols = torch.unique(torch.nonzero(img[0, :, :])[:, 1])
86 diffs = img[0, :, col].roll(1) - img[0, :, col].roll(-1)
89 deadchlowcharge_hidden = (diffs.abs() > 20).
float() * deadchnocharge[i, 0, :, col].
float()
90 deadchnocharge[i, 0, :, col] -= deadchlowcharge_hidden
91 deadchlowcharge[i, 0, :, col] += deadchlowcharge_hidden
93 predictdeadlowcharge = predict*deadchlowcharge
94 adcdeadlowcharge = true*deadchlowcharge
95 totdeadlowcharge = deadchlowcharge.sum().
float()
96 if (totdeadlowcharge == 0):
97 totdeadlowcharge = 1.0
98 deadlowchargeloss = (L1loss(predictdeadlowcharge,adcdeadlowcharge)*deadlowchargeweight)/totdeadlowcharge
100 predictdeadnocharge = predict*deadchnocharge
101 adcdeadnocharge = true*deadchnocharge
102 totdeadnocharge = deadchnocharge.sum().
float()
103 if (totdeadnocharge == 0):
104 totdeadnocharge = 1.0
105 deadnochargeloss = (L1loss(predictdeadnocharge,adcdeadnocharge)*deadnochargeweight)/totdeadnocharge
107 totloss = nondeadloss + deadnochargeloss + deadlowchargeloss +deadhighchargeloss+deadhighestchargeloss
108 return nondeadloss, deadnochargeloss, deadlowchargeloss, deadhighchargeloss,deadhighestchargeloss, totloss
def _assert_no_grad(variable)
auto enumerate(Iterables &&...iterables)
Range-for loop helper that tracks the iteration count.
def forward(self, predict, true, input)