infill_loss.InfillLossInduction Class Reference

Public Member Functions

def __init__ (self, weight=None, size_average=False, ignore_index=-100)
 
def forward (self, predict, true, input)
 

Public Attributes

 ignore_index
 
 reduce
 
 size_average
 

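Detailed Description

Weighted L1 loss for dead-channel infilling. As the source below shows, forward() treats every column (channel) whose input values are all zero as dead and splits the pixels into five regions: live channels, and dead channels whose true charge is none (< 10), low, high or highest. With the default size_average=False it computes a weighted mean absolute error over each region, then returns the five terms followed by their sum.

Definition at line 22 of file infill_loss.py.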
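With the default size_average=False, torch.nn.L1Loss sums |p - t| over every element. forward() multiplies both tensors by a 0/1 mask before the call, so each returned term is sum(|p - t| over the region) * weight / count, where count is the number of masked pixels and is replaced by 1 for an empty region. The dependency-free sketch below repeats that arithmetic on flat Python lists standing in for tensors. The name region_l1 is hypothetical, and the sketch assumes size_average=False.

```python
def region_l1(predict, true, mask, weight):
    """Weighted mean absolute error over the pixels where mask == 1.

    Mirrors one term of forward() under the default size_average=False:
    sum(|predict*mask - true*mask|) * weight / count, where count is the
    number of masked pixels, replaced by 1 when the mask is empty.
    """
    summed = sum(abs(p * m - t * m) for p, t, m in zip(predict, true, mask))
    count = sum(mask)
    if count == 0:
        count = 1.0  # same guard as the listing: avoid dividing by zero
    return summed * weight / count


# Toy example: four pixels, two of them inside the region.
predict = [5.0, 0.0, 12.0, 30.0]
true = [3.0, 0.0, 20.0, 90.0]
mask = [1, 0, 1, 0]
print(region_l1(predict, true, mask, weight=100.0))
```

forward() returns five such terms, with weights 1.0 (live channels), 500.0 (dead, no charge) and 100.0 (dead, low, high and highest charge), plus their sum.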

Constructor & Destructor Documentation

def infill_loss.InfillLossInduction.__init__(self, weight=None, size_average=False, ignore_index=-100)

Definition at line 23 of file infill_loss.py.

 23  def __init__(self, weight=None, size_average=False, ignore_index=-100):
 24      super(InfillLossInduction, self).__init__(weight, size_average)
 25      self.ignore_index = ignore_index
 26      self.reduce = False
 27      self.size_average = size_average
 28
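forward() builds its criterion as torch.nn.L1Loss(self.size_average), passing only the deprecated size_average flag. PyTorch converts a legacy (size_average, reduce) pair into a reduction string, treating an unset flag as True. The sketch below restates that rule in plain Python. legacy_reduction is a hypothetical name, not a PyTorch function. It shows why the default size_average=False gives a summed loss, and why the constructor's self.reduce = False never reaches the L1Loss built in forward().

```python
def legacy_reduction(size_average=None, reduce=None):
    """Map the deprecated size_average/reduce pair to a reduction string.

    Restates the rule PyTorch applies when a loss is constructed with the
    legacy arguments: a flag left as None counts as True.
    """
    if size_average is None:
        size_average = True
    if reduce is None:
        reduce = True
    if size_average and reduce:
        return "mean"
    if reduce:
        return "sum"
    return "none"


# forward() calls torch.nn.L1Loss(self.size_average), so reduce stays None.
print(legacy_reduction(size_average=False))  # sum  (the class default)
print(legacy_reduction(size_average=True))   # mean
```

Constructing torch.nn.L1Loss(reduction='sum') states the default behaviour directly and avoids the deprecation warning that recent PyTorch versions emit when size_average is passed.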

Member Function Documentation

def infill_loss.InfillLossInduction.forward(self, predict, true, input)

Definition at line 29 of file infill_loss.py.

 29  def forward(self, predict, true, input):
 30      _assert_no_grad(true)
 31
 32      # Want five losses: non-dead, and dead with no, low, high and highest true charge
 33      L1loss=torch.nn.L1Loss(self.size_average)
 34      nondeadweight = 1.0
 35      deadnochargeweight = 500.0
 36      deadlowchargeweight = 100.0
 37      deadhighchargeweight = 100.0
 38      deadhighestchargeweight = 100.0
 39
 40      # Identify dead channels
 41      goodch = (input != 0)
 42      for idx, col in enumerate(goodch[0, 0, :, :].T):
 43          if col.sum():
 44              goodch[0, 0, :, idx] = 1
 45      deadch = (~goodch).float()
 46      goodch = goodch.float()
 47
 48      # Compute non-dead loss
 49      predictgood = goodch * predict
 50      adcgood = goodch * true
 51      totnondead = goodch.sum().float()
 52      if (totnondead == 0):
 53          totnondead = 1.0
 54      nondeadloss = (L1loss(predictgood, adcgood)*nondeadweight)/totnondead
 55
 56      # Compute dead with highest true charge loss
 57      deadchhighestcharge = deadch * (true.abs() > 70).float()
 58      predictdeadhighestcharge = predict*deadchhighestcharge
 59      adcdeadhighestcharge = true*deadchhighestcharge
 60      totdeadhighestcharge = deadchhighestcharge.sum().float()
 61      if (totdeadhighestcharge == 0):
 62          totdeadhighestcharge = 1.0
 63      deadhighestchargeloss = (L1loss(predictdeadhighestcharge,adcdeadhighestcharge)*
 64          deadhighestchargeweight)/totdeadhighestcharge
 65
 66      # Compute dead with high true charge loss
 67      deadchhighcharge = deadch * (true.abs() > 40).float() * (true.abs() < 70).float()
 68      predictdeadhighcharge = predict*deadchhighcharge
 69      adcdeadhighcharge = true*deadchhighcharge
 70      totdeadhighcharge = deadchhighcharge.sum().float()
 71      if (totdeadhighcharge == 0):
 72          totdeadhighcharge = 1.0
 73      deadhighchargeloss = (L1loss(predictdeadhighcharge,adcdeadhighcharge)*
 74          deadhighchargeweight)/totdeadhighcharge
 75
 76      # Compute dead with low true charge and dead with no (< 10) true charge
 77      deadchlowcharge = deadch * (true.abs() > 10).float() * (true.abs() < 40).float()
 78      deadchnocharge = deadch * (true.abs() < 10).float()
 79
 80      # < 10 charge regions where the bipolar signal crosses zero should be low true charge rather than no charge.
 81      # Identify them and move them from deadnocharge to deadlowcharge
 82      dead_cols = torch.nonzero(deadch)
 83      for i, img in enumerate(true * deadch):
 84          dead_cols = torch.unique(torch.nonzero(img[0, :, :])[:, 1])
 85          for col in dead_cols:
 86              diffs = img[0, :, col].roll(1) - img[0, :, col].roll(-1)
 87              diffs[0] = 0
 88              diffs[-1] = 0
 89              deadchlowcharge_hidden = (diffs.abs() > 20).float() * deadchnocharge[i, 0, :, col].float()
 90              deadchnocharge[i, 0, :, col] -= deadchlowcharge_hidden
 91              deadchlowcharge[i, 0, :, col] += deadchlowcharge_hidden
 92
 93      predictdeadlowcharge = predict*deadchlowcharge
 94      adcdeadlowcharge = true*deadchlowcharge
 95      totdeadlowcharge = deadchlowcharge.sum().float()
 96      if (totdeadlowcharge == 0):
 97          totdeadlowcharge = 1.0
 98      deadlowchargeloss = (L1loss(predictdeadlowcharge,adcdeadlowcharge)*deadlowchargeweight)/totdeadlowcharge
 99
100      predictdeadnocharge = predict*deadchnocharge
101      adcdeadnocharge = true*deadchnocharge
102      totdeadnocharge = deadchnocharge.sum().float()
103      if (totdeadnocharge == 0):
104          totdeadnocharge = 1.0
105      deadnochargeloss = (L1loss(predictdeadnocharge,adcdeadnocharge)*deadnochargeweight)/totdeadnocharge
106
107      totloss = nondeadloss + deadnochargeloss + deadlowchargeloss +deadhighchargeloss+deadhighestchargeloss
108      return nondeadloss, deadnochargeloss, deadlowchargeloss, deadhighchargeloss,deadhighestchargeloss, totloss
109
110
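The dead-channel step marks a column as live if its input holds any non-zero sample, and dead otherwise. As listed, the column fill only writes to goodch[0, 0, :, idx], the first image and first channel of the batch. For any other image, goodch stays the element-wise test input != 0, so zero samples inside live columns end up in deadch. This only matters when a batch holds more than one image or channel. The sketch below applies the column rule to one 2-D image given as a list of rows (ticks) by columns (channels). dead_column_mask is a hypothetical name, and plain lists stand in for tensors.

```python
def dead_column_mask(image):
    """Return a 0/1 mask shaped like `image`, with 1 in every all-zero column.

    `image` is a non-empty list of rows (time ticks); each row holds one
    value per column (channel), mirroring input[i, 0] in forward().
    """
    n_cols = len(image[0])
    dead_cols = {c for c in range(n_cols) if all(row[c] == 0 for row in image)}
    # Every tick of a dead column is marked dead; live columns are marked 0
    # everywhere, including ticks where the input happens to be zero.
    return [[1 if c in dead_cols else 0 for c in range(n_cols)] for _ in image]


# Columns 0 and 2 are all zero; column 1 is live despite its zero middle tick.
image = [
    [0, 4, 0],
    [0, 0, 0],
    [0, -3, 0],
]
for row in dead_column_mask(image):
    print(row)
```

Applying dead_column_mask to input[i, 0] for every i would extend to the whole batch what the listing does for i = 0.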
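Within dead channels, forward() bins each pixel by |true| using strict inequalities at 10, 40 and 70. A pixel whose |true| is exactly 10, 40 or 70 satisfies none of them and contributes to no dead-channel term. The sketch below restates the bins with half-open intervals so that every value lands in exactly one bin. Putting each boundary value in the upper bin is our assumption, not something the source specifies. Both function names are hypothetical, and the zero-crossing reclassification that forward() applies afterwards is not included.

```python
def charge_bin(adc, low=10.0, high=40.0, highest=70.0):
    """Assign a dead-channel pixel to a charge bin by its absolute true charge.

    Half-open intervals: [0, low) -> "none", [low, high) -> "low",
    [high, highest) -> "high", [highest, inf) -> "highest".
    """
    a = abs(adc)
    if a < low:
        return "none"
    if a < high:
        return "low"
    if a < highest:
        return "high"
    return "highest"


def listed_bins(adc):
    """The bins exactly as forward() writes them; may return an empty list."""
    a = abs(adc)
    bins = []
    if a > 70:
        bins.append("highest")
    if 40 < a < 70:
        bins.append("high")
    if 10 < a < 40:
        bins.append("low")
    if a < 10:
        bins.append("none")
    return bins


# Boundary values 10, 40 and -70 fall into no bin in the listed version.
for value in (5, 10, 25, 40, -70, 120):
    print(value, charge_bin(value), listed_bins(value))
```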

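A bipolar induction signal passes through zero between its positive and negative lobes, so samples near the crossing have |true| < 10 even though the channel carries charge. forward() catches them with a centred difference along each dead column: the two torch.roll calls give diffs[t] = x[t-1] - x[t+1], and both end samples are zeroed to discard the wrap-around. Any no-charge sample with |diffs[t]| > 20 is moved to the low-charge region. The sketch below repeats that rule on a single column given as a Python list. zero_crossing_flags and the toy pulse are hypothetical.

```python
def zero_crossing_flags(column, no_charge, threshold=20.0):
    """Flag no-charge samples that sit on a steep slope of the waveform.

    column:    true values along one dead column (one per time tick).
    no_charge: 0/1 flags from the < 10 charge bin for the same ticks.
    Returns 0/1 flags for samples to move from no-charge to low-charge.
    """
    n = len(column)
    flags = [0] * n
    # Skip t = 0 and t = n - 1, where torch.roll would wrap around
    # (the listing zeroes diffs[0] and diffs[-1] for the same reason).
    for t in range(1, n - 1):
        diff = column[t - 1] - column[t + 1]
        if abs(diff) > threshold and no_charge[t]:
            flags[t] = 1
    return flags


# A toy bipolar pulse: positive lobe, zero crossing, negative lobe.
column = [0, 15, 30, 2, -28, -12, 0]
no_charge = [1 if abs(v) < 10 else 0 for v in column]
moved = zero_crossing_flags(column, no_charge)
# Same bookkeeping as the listing: subtract from no-charge, add to low-charge.
no_charge_after = [nc - m for nc, m in zip(no_charge, moved)]
print(moved)
print(no_charge_after)
```

Only the sample at the crossing (value 2, between 30 and -28) is moved; the samples at the ends of the column keep their no-charge label.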
Member Data Documentation

infill_loss.InfillLossInduction.ignore_index

Stored by the constructor; forward() does not use it.

Definition at line 25 of file infill_loss.py.

infill_loss.InfillLossInduction.reduce

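Set to False by the constructor; forward() does not use it, since it builds its own torch.nn.L1Loss from size_average alone.

Definition at line 26 of file infill_loss.py.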

infill_loss.InfillLossInduction.size_average

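Passed to torch.nn.L1Loss in forward(). With the default False, each region's absolute errors are summed and then divided by the number of pixels in that region.

Definition at line 27 of file infill_loss.py.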


The documentation for this class was generated from the following file: