3 This loss mimics nividia's pixelwise loss for holes (L1)     4 used in the infill network     6 Loss function designed for a network that infills tracks. Separate functions for collection views     7 (unipolar signals) and induction views (bipolar signal)     8 Taken from github.com/NuTufts/sparse_infill.    17     assert not variable.requires_grad, \
    18         "nn criterions don't compute the gradient w.r.t. targets - please " \
    19         "mark these variables as not requiring gradients"    23     def __init__(self, weight=None, size_average=False, ignore_index=-100):
    24         super(InfillLossInduction, self).
__init__(weight, size_average)
    35         deadnochargeweight = 500.0
    36         deadlowchargeweight = 100.0
    37         deadhighchargeweight = 100.0
    38         deadhighestchargeweight = 100.0
    42         for idx, col 
in enumerate(goodch[0, 0, :, :].T):
    44                 goodch[0, 0, :, idx] = 1
    45         deadch = (~goodch).
float()
    46         goodch = goodch.float()
    49         predictgood = goodch * predict
    50         adcgood = goodch * true
    51         totnondead = goodch.sum().
float()
    54         nondeadloss = (L1loss(predictgood, adcgood)*nondeadweight)/totnondead
    57         deadchhighestcharge = deadch * (true.abs() > 70).
float()
    58         predictdeadhighestcharge = predict*deadchhighestcharge
    59         adcdeadhighestcharge = true*deadchhighestcharge
    60         totdeadhighestcharge = deadchhighestcharge.sum().
float()
    61         if (totdeadhighestcharge == 0):
    62             totdeadhighestcharge = 1.0
    63         deadhighestchargeloss = (L1loss(predictdeadhighestcharge,adcdeadhighestcharge)*
    64             deadhighestchargeweight)/totdeadhighestcharge
    67         deadchhighcharge = deadch * (true.abs() > 40).
float() * (true.abs() < 70).
float()
    68         predictdeadhighcharge = predict*deadchhighcharge
    69         adcdeadhighcharge = true*deadchhighcharge
    70         totdeadhighcharge = deadchhighcharge.sum().
float()
    71         if (totdeadhighcharge == 0):
    72             totdeadhighcharge = 1.0
    73         deadhighchargeloss = (L1loss(predictdeadhighcharge,adcdeadhighcharge)*
    74             deadhighchargeweight)/totdeadhighcharge
    77         deadchlowcharge = deadch * (true.abs() > 10).
float() * (true.abs() < 40).
float()
    78         deadchnocharge = deadch * (true.abs() < 10).
float()
    82         dead_cols = torch.nonzero(deadch)
    84             dead_cols = torch.unique(torch.nonzero(img[0, :, :])[:, 1])
    86                 diffs = img[0, :, col].roll(1) - img[0, :, col].roll(-1)
    89                 deadchlowcharge_hidden = (diffs.abs() > 20).
float() * deadchnocharge[i, 0, :, col].
float()
    90                 deadchnocharge[i, 0, :, col] -= deadchlowcharge_hidden
    91                 deadchlowcharge[i, 0, :, col] += deadchlowcharge_hidden
    93         predictdeadlowcharge = predict*deadchlowcharge
    94         adcdeadlowcharge = true*deadchlowcharge
    95         totdeadlowcharge = deadchlowcharge.sum().
float()
    96         if (totdeadlowcharge == 0):
    97             totdeadlowcharge = 1.0
    98         deadlowchargeloss = (L1loss(predictdeadlowcharge,adcdeadlowcharge)*deadlowchargeweight)/totdeadlowcharge
   100         predictdeadnocharge = predict*deadchnocharge
   101         adcdeadnocharge = true*deadchnocharge
   102         totdeadnocharge = deadchnocharge.sum().
float()
   103         if (totdeadnocharge == 0):
   104             totdeadnocharge = 1.0
   105         deadnochargeloss = (L1loss(predictdeadnocharge,adcdeadnocharge)*deadnochargeweight)/totdeadnocharge
   107         totloss = nondeadloss + deadnochargeloss + deadlowchargeloss +deadhighchargeloss+deadhighestchargeloss
   108         return nondeadloss, deadnochargeloss, deadlowchargeloss, deadhighchargeloss,deadhighestchargeloss, totloss
   112     def __init__(self, weight=None, size_average=False, ignore_index=-100):
   113         super(InfillLossCollection, self).
__init__(weight, size_average)
   124         deadnochargeweight = 500.0
   125         deadlowchargeweight = 100.0
   126         deadhighchargeweight = 100.0
   127         deadhighestchargeweight = 100.0
   130         goodch = (input != 0)
   131         for idx, col 
in enumerate(goodch[0, 0, :, :].T):
   133                 goodch[0, 0, :, idx] = 1
   134         deadch = (~goodch).
float()
   135         goodch = goodch.float()
   138         predictgood = goodch * predict
   139         adcgood = goodch * true
   140         totnondead = goodch.sum().
float()
   141         if (totnondead == 0):
   143         nondeadloss = (L1loss(predictgood, adcgood)*nondeadweight)/totnondead
   146         deadchhighestcharge = deadch * (true.abs() > 70).
float()
   147         predictdeadhighestcharge = predict*deadchhighestcharge
   148         adcdeadhighestcharge = true*deadchhighestcharge
   149         totdeadhighestcharge = deadchhighestcharge.sum().
float()
   150         if (totdeadhighestcharge == 0):
   151             totdeadhighestcharge = 1.0
   152         deadhighestchargeloss = (L1loss(predictdeadhighestcharge,adcdeadhighestcharge)*deadhighestchargeweight)/totdeadhighestcharge
   155         deadchhighcharge = deadch * (true.abs() > 40).
float()*(true.abs() < 70).
float()
   156         predictdeadhighcharge = predict*deadchhighcharge
   157         adcdeadhighcharge = true*deadchhighcharge
   158         totdeadhighcharge = deadchhighcharge.sum().
float()
   159         if (totdeadhighcharge == 0):
   160             totdeadhighcharge = 1.0
   161         deadhighchargeloss = (L1loss(predictdeadhighcharge,adcdeadhighcharge)*deadhighchargeweight)/totdeadhighcharge
   164         deadchlowcharge = deadch * (true.abs() > 10).
float() *(true.abs() < 40).
float()
   165         deadchnocharge = deadch * (true.abs() < 10).
float()
   166         predictdeadlowcharge = predict*deadchlowcharge
   167         adcdeadlowcharge = true*deadchlowcharge
   168         totdeadlowcharge = deadchlowcharge.sum().
float()
   169         if (totdeadlowcharge == 0):
   170             totdeadlowcharge = 1.0
   171         deadlowchargeloss = (L1loss(predictdeadlowcharge,adcdeadlowcharge)*deadlowchargeweight)/totdeadlowcharge
   174         predictdeadnocharge = predict*deadchnocharge
   175         adcdeadnocharge = true*deadchnocharge
   176         totdeadnocharge = deadchnocharge.sum().
float()
   177         if (totdeadnocharge == 0):
   178             totdeadnocharge = 1.0
   179         deadnochargeloss = (L1loss(predictdeadnocharge,adcdeadnocharge)*deadnochargeweight)/totdeadnocharge
   181         totloss = nondeadloss + deadnochargeloss + deadlowchargeloss +deadhighchargeloss+deadhighestchargeloss
   182         return nondeadloss, deadnochargeloss, deadlowchargeloss, deadhighchargeloss,deadhighestchargeloss, totloss
 def _assert_no_grad(variable)
def __init__(self, weight=None, size_average=False, ignore_index=-100)
auto enumerate(Iterables &&...iterables)
Range-for loop helper tracking the number of iteration. 
def forward(self, predict, true, input)
def forward(self, predict, true, input)
def __init__(self, weight=None, size_average=False, ignore_index=-100)