Public Member Functions | Public Attributes | List of all members
infill_loss.InfillLossCollection Class Reference
Inheritance diagram for infill_loss.InfillLossCollection:

Public Member Functions

def __init__ (self, weight=None, size_average=False, ignore_index=-100)
 
def forward (self, predict, true, input)
 

Public Attributes

 ignore_index
 
 reduce
 
 size_average
 

Detailed Description

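Loss module that splits a weighted L1 loss into separate terms.

- **Non-dead term:** one term covers live ("non-dead") channels.
- **Dead-channel terms:** four terms cover dead channels, binned by the magnitude of the true value: |true| < 10, 10 < |true| < 40, 40 < |true| < 70 and |true| > 70.
- **Normalization:** each term is divided by the number of elements in its mask, or by 1 if the mask is empty.

A channel counts as dead when `input` is zero along its entire column. Only `input[0, 0]` is scanned for such columns; for all other entries the live/dead mask stays element-wise.

Every bin boundary uses a strict inequality. As a result, dead-channel values with |true| exactly equal to 10, 40 or 70 fall into none of the dead-channel terms.

Definition at line 111 of file infill_loss.py.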

Constructor & Destructor Documentation

def infill_loss.InfillLossCollection.__init__ (   self,
  weight = None,
  size_average = False,
  ignore_index = -100 
)

Definition at line 112 of file infill_loss.py.

112  def __init__(self, weight=None, size_average=False, ignore_index=-100):
113  super(InfillLossCollection, self).__init__(weight, size_average)
114  self.ignore_index = ignore_index
115  self.reduce = False
116  self.size_average = size_average
117 
def __init__(self, weight=None, size_average=False, ignore_index=-100)
Definition: infill_loss.py:112

Member Function Documentation

def infill_loss.InfillLossCollection.forward (   self,
  predict,
  true,
  input 
)

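Definition at line 118 of file infill_loss.py.

Returns a 6-tuple `(nondeadloss, deadnochargeloss, deadlowchargeloss, deadhighchargeloss, deadhighestchargeloss, totloss)`, where `totloss` is the sum of the other five terms.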

118  def forward(self, predict, true, input):
119  _assert_no_grad(true)
120 
121  # Want three losses: non-dead, dead w/o charge, dead w/ charge
122  L1loss=torch.nn.L1Loss(self.size_average)
123  nondeadweight = 1.0
124  deadnochargeweight = 500.0
125  deadlowchargeweight = 100.0
126  deadhighchargeweight = 100.0
127  deadhighestchargeweight = 100.0
128 
129  # Identify dead channels
130  goodch = (input != 0)
131  for idx, col in enumerate(goodch[0, 0, :, :].T):
132  if col.sum():
133  goodch[0, 0, :, idx] = 1
134  deadch = (~goodch).float()
135  goodch = goodch.float()
136 
137  # Compute non-dead loss
138  predictgood = goodch * predict
139  adcgood = goodch * true
140  totnondead = goodch.sum().float()
141  if (totnondead == 0):
142  totnondead = 1.0
143  nondeadloss = (L1loss(predictgood, adcgood)*nondeadweight)/totnondead
144 
145  # Compute dead with highest true charge loss
146  deadchhighestcharge = deadch * (true.abs() > 70).float()
147  predictdeadhighestcharge = predict*deadchhighestcharge
148  adcdeadhighestcharge = true*deadchhighestcharge
149  totdeadhighestcharge = deadchhighestcharge.sum().float()
150  if (totdeadhighestcharge == 0):
151  totdeadhighestcharge = 1.0
152  deadhighestchargeloss = (L1loss(predictdeadhighestcharge,adcdeadhighestcharge)*deadhighestchargeweight)/totdeadhighestcharge
153 
154  # Compute dead with high true charge loss
155  deadchhighcharge = deadch * (true.abs() > 40).float()*(true.abs() < 70).float()
156  predictdeadhighcharge = predict*deadchhighcharge
157  adcdeadhighcharge = true*deadchhighcharge
158  totdeadhighcharge = deadchhighcharge.sum().float()
159  if (totdeadhighcharge == 0):
160  totdeadhighcharge = 1.0
161  deadhighchargeloss = (L1loss(predictdeadhighcharge,adcdeadhighcharge)*deadhighchargeweight)/totdeadhighcharge
162 
163  # Compute dead with low true charge loss
164  deadchlowcharge = deadch * (true.abs() > 10).float() *(true.abs() < 40).float()
165  deadchnocharge = deadch * (true.abs() < 10).float()
166  predictdeadlowcharge = predict*deadchlowcharge
167  adcdeadlowcharge = true*deadchlowcharge
168  totdeadlowcharge = deadchlowcharge.sum().float()
169  if (totdeadlowcharge == 0):
170  totdeadlowcharge = 1.0
171  deadlowchargeloss = (L1loss(predictdeadlowcharge,adcdeadlowcharge)*deadlowchargeweight)/totdeadlowcharge
172 
173  # Compute dead with no (< 10) true charge loss
174  predictdeadnocharge = predict*deadchnocharge
175  adcdeadnocharge = true*deadchnocharge
176  totdeadnocharge = deadchnocharge.sum().float()
177  if (totdeadnocharge == 0):
178  totdeadnocharge = 1.0
179  deadnochargeloss = (L1loss(predictdeadnocharge,adcdeadnocharge)*deadnochargeweight)/totdeadnocharge
180 
181  totloss = nondeadloss + deadnochargeloss + deadlowchargeloss +deadhighchargeloss+deadhighestchargeloss
182  return nondeadloss, deadnochargeloss, deadlowchargeloss, deadhighchargeloss,deadhighestchargeloss, totloss
def _assert_no_grad(variable)
Definition: infill_loss.py:16
auto enumerate(Iterables &&...iterables)
Range-for loop helper tracking the number of iteration.
Definition: enumerate.h:69
def forward(self, predict, true, input)
Definition: infill_loss.py:118

Member Data Documentation

infill_loss.InfillLossCollection.ignore_index

Definition at line 114 of file infill_loss.py.

Stored from the constructor's `ignore_index` argument (default -100). `forward()` does not use it.

infill_loss.InfillLossCollection.reduce

Definition at line 115 of file infill_loss.py.

Always set to False by the constructor. `forward()` does not read it.

infill_loss.InfillLossCollection.size_average

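Definition at line 116 of file infill_loss.py.

Stored from the constructor's `size_average` argument (default False). `forward()` passes it as the first positional argument to the `torch.nn.L1Loss` it creates on each call.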


The documentation for this class was generated from the following file: