Classes | Public Member Functions | Static Public Member Functions | Private Member Functions | Private Attributes | List of all members
nnet::PointIdAlg Class Reference

#include <PointIdAlg.h>

Inheritance diagram for nnet::PointIdAlg:
img::DataProviderAlg

Classes

struct  Config
 

Public Member Functions

 PointIdAlg (const fhicl::ParameterSet &pset)
 
 PointIdAlg (const Config &config)
 
 ~PointIdAlg () override
 
std::vector< std::string > const & outputLabels () const
 network output labels More...
 
float predictIdValue (unsigned int wire, float drift, size_t outIdx=0) const
 calculate single-value prediction (2-class probability) for [wire, drift] point More...
 
std::vector< float > predictIdVector (unsigned int wire, float drift) const
 calculate multi-class probabilities for [wire, drift] point More...
 
std::vector< std::vector< float > > predictIdVectors (std::vector< std::pair< unsigned int, float >> points) const
 
std::vector< std::vector< float > > const & patchData2D () const
 
std::vector< float > patchData1D () const
 
bool isInsideFiducialRegion (unsigned int wire, float drift) const
 
bool isCurrentPatch (unsigned int wire, float drift) const
 
bool isSamePatch (unsigned int wire1, float drift1, unsigned int wire2, float drift2) const
 test if two wire/drift coordinates point to the same patch More...
 
- Public Member Functions inherited from img::DataProviderAlg
 DataProviderAlg (const fhicl::ParameterSet &pset)
 
 DataProviderAlg (const Config &config)
 
virtual ~DataProviderAlg ()
 
bool setWireDriftData (const detinfo::DetectorClocksData &clock_data, const detinfo::DetectorPropertiesData &det_prop, const std::vector< recob::Wire > &wires, unsigned int plane, unsigned int tpc, unsigned int cryo)
 
std::vector< float > const & wireData (size_t widx) const
 
std::vector< std::vector< float > > getPatch (size_t wire, float drift, size_t patchSizeW, size_t patchSizeD) const
 
float getPixelOrZero (int wire, int drift) const
 
double getAdcSum () const
 
size_t getAdcArea () const
 
float poolMax (int wire, int drift, size_t r=0) const
 Pool max value in a patch around the wire/drift pixel. More...
 
unsigned int Cryo () const
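 Note: the brief text "Pool sum of pixels in a patch around the wire/drift pixel." shown here for Cryo() appears to be misattributed by the documentation generator — Cryo() takes no wire/drift arguments and does not pool pixels.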
 
unsigned int TPC () const
 
unsigned int Plane () const
 
unsigned int NWires () const
 
unsigned int NScaledDrifts () const
 
unsigned int NCachedDrifts () const
 
unsigned int DriftWindow () const
 
float ZeroLevel () const
 Level of zero ADC after scaling. More...
 
double LifetimeCorrection (detinfo::DetectorClocksData const &clock_data, detinfo::DetectorPropertiesData const &det_prop, double tick) const
 

Static Public Member Functions

static std::vector< float > flattenData2D (std::vector< std::vector< float >> const &patch)
 

Private Member Functions

bool bufferPatch (size_t wire, float drift, std::vector< std::vector< float >> &patch) const
 
bool bufferPatch (size_t wire, float drift) const
 
void resizePatch ()
 
void deleteNNet ()
 

Private Attributes

std::string fNNetModelFilePath
 
std::vector< std::string > fNNetOutputs
 
nnet::ModelInterface * fNNet
 
std::vector< std::vector< float > > fWireDriftPatch
 
size_t fPatchSizeW
 
size_t fPatchSizeD
 
size_t fCurrentWireIdx
 
size_t fCurrentScaledDrift
 

Additional Inherited Members

- Public Types inherited from img::DataProviderAlg
enum  EDownscaleMode { kMax = 1, kMaxMean = 2, kMean = 3 }
 
- Protected Member Functions inherited from img::DataProviderAlg
std::vector< float > downscaleMax (std::size_t dst_size, std::vector< float > const &adc, size_t tick0) const
 
std::vector< float > downscaleMaxMean (std::size_t dst_size, std::vector< float > const &adc, size_t tick0) const
 
std::vector< float > downscaleMean (std::size_t dst_size, std::vector< float > const &adc, size_t tick0) const
 
std::vector< float > downscale (std::size_t dst_size, std::vector< float > const &adc, size_t tick0) const
 
size_t getDriftIndex (float drift) const
 
std::optional< std::vector< float > > setWireData (std::vector< float > const &adc, size_t wireIdx) const
 
bool patchFromDownsampledView (size_t wire, float drift, size_t size_w, size_t size_d, std::vector< std::vector< float >> &patch) const
 
bool patchFromOriginalView (size_t wire, float drift, size_t size_w, size_t size_d, std::vector< std::vector< float >> &patch) const
 
virtual DataProviderAlgView resizeView (detinfo::DetectorClocksData const &clock_data, detinfo::DetectorPropertiesData const &det_prop, size_t wires, size_t drifts)
 
- Protected Attributes inherited from img::DataProviderAlg
DataProviderAlgView fAlgView
 
EDownscaleMode fDownscaleMode
 
size_t fDriftWindow
 
bool fDownscaleFullView
 
float fDriftWindowInv
 
calo::CalorimetryAlg fCalorimetryAlg
 
geo::GeometryCore const * fGeometry
 

Detailed Description

Definition at line 88 of file PointIdAlg.h.

Constructor & Destructor Documentation

nnet::PointIdAlg::PointIdAlg ( const fhicl::ParameterSet & pset)
inline

Definition at line 104 of file PointIdAlg.h.

104 : PointIdAlg(fhicl::Table<Config>(pset, {})()) {}
PointIdAlg(const fhicl::ParameterSet &pset)
Definition: PointIdAlg.h:104
nnet::PointIdAlg::PointIdAlg ( const Config & config)

Definition at line 152 of file PointIdAlg.cxx.

154  , fNNet(0)
157  , fCurrentWireIdx(99999)
158  , fCurrentScaledDrift(99999)
159 {
162 
163  deleteNNet();
164 
165  if ((fNNetModelFilePath.length() > 5) &&
166  (fNNetModelFilePath.compare(fNNetModelFilePath.length() - 5, 5, ".nnet") == 0)) {
168  }
169  else if ((fNNetModelFilePath.length() > 3) &&
170  (fNNetModelFilePath.compare(fNNetModelFilePath.length() - 3, 3, ".pb") == 0)) {
172  }
173  else {
174  mf::LogError("PointIdAlg") << "File name extension not supported.";
175  }
176 
177  if (!fNNet) { throw cet::exception("nnet::PointIdAlg") << "Loading model from file failed."; }
178 
179  resizePatch();
180 }
fhicl::Atom< unsigned int > PatchSizeW
Definition: PointIdAlg.h:98
MaybeLogger_< ELseverityLevel::ELsev_error, false > LogError
std::vector< std::string > fNNetOutputs
Definition: PointIdAlg.h:150
fhicl::Atom< unsigned int > PatchSizeD
Definition: PointIdAlg.h:100
fhicl::Sequence< std::string > NNetOutputs
Definition: PointIdAlg.h:96
static Config * config
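Definition: config.cpp:1054 (spurious cross-reference: this symbol is unrelated to the constructor parameter `const Config &config`)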
size_t fPatchSizeD
Definition: PointIdAlg.h:154
fhicl::Atom< std::string > NNetModelFile
Definition: PointIdAlg.h:94
size_t fCurrentScaledDrift
Definition: PointIdAlg.h:156
nnet::ModelInterface * fNNet
Definition: PointIdAlg.h:151
size_t fPatchSizeW
Definition: PointIdAlg.h:154
std::string fNNetModelFilePath
Definition: PointIdAlg.h:149
cet::coded_exception< error, detail::translate > exception
Definition: exception.h:33
size_t fCurrentWireIdx
Definition: PointIdAlg.h:156
nnet::PointIdAlg::~PointIdAlg ( )
override

Definition at line 183 of file PointIdAlg.cxx.

184 {
185  deleteNNet();
186 }

Member Function Documentation

bool nnet::PointIdAlg::bufferPatch ( size_t  wire,
float  drift,
std::vector< std::vector< float >> &  patch 
) const
inline, private

Definition at line 158 of file PointIdAlg.h.

159  {
160  if (fDownscaleFullView) {
161  size_t sd = (size_t)(drift / fDriftWindow);
162  if ((fCurrentWireIdx == wire) && (fCurrentScaledDrift == sd))
163  return true; // still within the current position
164 
165  fCurrentWireIdx = wire;
166  fCurrentScaledDrift = sd;
167 
168  return patchFromDownsampledView(wire, drift, fPatchSizeW, fPatchSizeD, patch);
169  }
170  else {
171  if ((fCurrentWireIdx == wire) && (fCurrentScaledDrift == drift))
172  return true; // still within the current position
173 
174  fCurrentWireIdx = wire;
175  fCurrentScaledDrift = drift;
176 
177  return patchFromOriginalView(wire, drift, fPatchSizeW, fPatchSizeD, patch);
178  }
179  }
size_t fPatchSizeD
Definition: PointIdAlg.h:154
bool patchFromOriginalView(size_t wire, float drift, size_t size_w, size_t size_d, std::vector< std::vector< float >> &patch) const
size_t fCurrentScaledDrift
Definition: PointIdAlg.h:156
bool patchFromDownsampledView(size_t wire, float drift, size_t size_w, size_t size_d, std::vector< std::vector< float >> &patch) const
size_t fPatchSizeW
Definition: PointIdAlg.h:154
size_t fCurrentWireIdx
Definition: PointIdAlg.h:156
bool nnet::PointIdAlg::bufferPatch ( size_t  wire,
float  drift 
) const
inline, private

Definition at line 181 of file PointIdAlg.h.

182  {
183  return bufferPatch(wire, drift, fWireDriftPatch);
184  }
std::vector< std::vector< float > > fWireDriftPatch
Definition: PointIdAlg.h:153
bool bufferPatch(size_t wire, float drift, std::vector< std::vector< float >> &patch) const
Definition: PointIdAlg.h:158
void nnet::PointIdAlg::deleteNNet ( )
inline, private

Definition at line 188 of file PointIdAlg.h.

189  {
190  if (fNNet) delete fNNet;
191  fNNet = 0;
192  }
nnet::ModelInterface * fNNet
Definition: PointIdAlg.h:151
std::vector< float > nnet::PointIdAlg::flattenData2D ( std::vector< std::vector< float >> const &  patch)
static

Definition at line 294 of file PointIdAlg.cxx.

295 {
296  std::vector<float> flat;
297  if (patch.empty() || patch.front().empty()) {
298  mf::LogError("DataProviderAlg") << "Patch is empty.";
299  return flat;
300  }
301 
302  flat.resize(patch.size() * patch.front().size());
303 
304  for (size_t w = 0, i = 0; w < patch.size(); ++w) {
305  auto const& wire = patch[w];
306  for (size_t d = 0; d < wire.size(); ++d, ++i) {
307  flat[i] = wire[d];
308  }
309  }
310 
311  return flat;
312 }
MaybeLogger_< ELseverityLevel::ELsev_error, false > LogError
bool nnet::PointIdAlg::isCurrentPatch ( unsigned int  wire,
float  drift 
) const

Test whether the wire/drift coordinates fall within the currently buffered patch (if so, the CNN output may not need to be recalculated).

Definition at line 277 of file PointIdAlg.cxx.

278 {
279  if (fDownscaleFullView) {
280  size_t sd = (size_t)(drift / fDriftWindow);
281  if ((fCurrentWireIdx == wire) && (fCurrentScaledDrift == sd))
282  return true; // still within the current position
283  }
284  else {
285  if ((fCurrentWireIdx == wire) && (fCurrentScaledDrift == drift))
286  return true; // still within the current position
287  }
288 
289  return false; // not a current position
290 }
size_t fCurrentScaledDrift
Definition: PointIdAlg.h:156
size_t fCurrentWireIdx
Definition: PointIdAlg.h:156
bool nnet::PointIdAlg::isInsideFiducialRegion ( unsigned int  wire,
float  drift 
) const

Definition at line 316 of file PointIdAlg.cxx.

317 {
318  size_t marginW = fPatchSizeW / 8; // fPatchSizeX/2 will make patch always completely filled
319  size_t marginD = fPatchSizeD / 8;
320 
321  size_t scaledDrift = (size_t)(drift / fDriftWindow);
322  if ((wire >= marginW) && (wire < fAlgView.fNWires - marginW) && (scaledDrift >= marginD) &&
323  (scaledDrift < fAlgView.fNScaledDrifts - marginD))
324  return true;
325  else
326  return false;
327 }
size_t fPatchSizeD
Definition: PointIdAlg.h:154
DataProviderAlgView fAlgView
size_t fPatchSizeW
Definition: PointIdAlg.h:154
bool nnet::PointIdAlg::isSamePatch ( unsigned int  wire1,
float  drift1,
unsigned int  wire2,
float  drift2 
) const

Test whether two wire/drift coordinates fall within the same patch.

Definition at line 259 of file PointIdAlg.cxx.

263 {
264  if (fDownscaleFullView) {
265  size_t sd1 = (size_t)(drift1 / fDriftWindow);
266  size_t sd2 = (size_t)(drift2 / fDriftWindow);
267  if ((wire1 == wire2) && (sd1 == sd2)) return true; // the same position
268  }
269  else {
270  if ((wire1 == wire2) && ((size_t)drift1 == (size_t)drift2)) return true; // the same position
271  }
272 
273  return false; // not the same position
274 }
std::vector<std::string> const& nnet::PointIdAlg::outputLabels ( ) const
inline

Network output labels.

Definition at line 112 of file PointIdAlg.h.

113  {
114  return fNNetOutputs;
115  }
std::vector< std::string > fNNetOutputs
Definition: PointIdAlg.h:150
std::vector<float> nnet::PointIdAlg::patchData1D ( ) const
inline

Definition at line 134 of file PointIdAlg.h.

135  {
137  } // flat vector made of the patch data, wire after wire
std::vector< std::vector< float > > fWireDriftPatch
Definition: PointIdAlg.h:153
static std::vector< float > flattenData2D(std::vector< std::vector< float >> const &patch)
Definition: PointIdAlg.cxx:294
std::vector<std::vector<float> > const& nnet::PointIdAlg::patchData2D ( ) const
inline

Definition at line 129 of file PointIdAlg.h.

130  {
131  return fWireDriftPatch;
132  }
std::vector< std::vector< float > > fWireDriftPatch
Definition: PointIdAlg.h:153
float nnet::PointIdAlg::predictIdValue ( unsigned int  wire,
float  drift,
size_t  outIdx = 0 
) const

Calculate a single-value prediction (2-class probability) for the [wire, drift] point.

Definition at line 199 of file PointIdAlg.cxx.

200 {
201  float result = 0.;
202 
203  if (!bufferPatch(wire, drift)) {
204  mf::LogError("PointIdAlg") << "Patch buffering failed.";
205  return result;
206  }
207 
208  if (fNNet) {
209  auto out = fNNet->Run(fWireDriftPatch);
210  if (!out.empty()) { result = out[outIdx]; }
211  else {
212  mf::LogError("PointIdAlg") << "Problem with applying model to input.";
213  }
214  }
215 
216  return result;
217 }
static QCString result (spurious cross-reference: unrelated to the local `result` variable used in this function)
std::vector< std::vector< float > > fWireDriftPatch
Definition: PointIdAlg.h:153
bool bufferPatch(size_t wire, float drift, std::vector< std::vector< float >> &patch) const
Definition: PointIdAlg.h:158
MaybeLogger_< ELseverityLevel::ELsev_error, false > LogError
virtual std::vector< float > Run(std::vector< std::vector< float >> const &inp2d)=0
nnet::ModelInterface * fNNet
Definition: PointIdAlg.h:151
std::vector< float > nnet::PointIdAlg::predictIdVector ( unsigned int  wire,
float  drift 
) const

Calculate multi-class probabilities for the [wire, drift] point.

Definition at line 221 of file PointIdAlg.cxx.

222 {
223  std::vector<float> result;
224 
225  if (!bufferPatch(wire, drift)) {
226  mf::LogError("PointIdAlg") << "Patch buffering failed.";
227  return result;
228  }
229 
230  if (fNNet) {
231  result = fNNet->Run(fWireDriftPatch);
232  if (result.empty()) { mf::LogError("PointIdAlg") << "Problem with applying model to input."; }
233  }
234 
235  return result;
236 }
static QCString result (spurious cross-reference: unrelated to the local `result` variable used in this function)
std::vector< std::vector< float > > fWireDriftPatch
Definition: PointIdAlg.h:153
bool bufferPatch(size_t wire, float drift, std::vector< std::vector< float >> &patch) const
Definition: PointIdAlg.h:158
MaybeLogger_< ELseverityLevel::ELsev_error, false > LogError
virtual std::vector< float > Run(std::vector< std::vector< float >> const &inp2d)=0
nnet::ModelInterface * fNNet
Definition: PointIdAlg.h:151
std::vector< std::vector< float > > nnet::PointIdAlg::predictIdVectors ( std::vector< std::pair< unsigned int, float >>  points) const

Definition at line 240 of file PointIdAlg.cxx.

241 {
242  if (points.empty() || !fNNet) { return std::vector<std::vector<float>>(); }
243 
244  std::vector<std::vector<std::vector<float>>> inps(
245  points.size(), std::vector<std::vector<float>>(fPatchSizeW, std::vector<float>(fPatchSizeD)));
246  for (size_t i = 0; i < points.size(); ++i) {
247  unsigned int wire = points[i].first;
248  float drift = points[i].second;
249  if (!bufferPatch(wire, drift, inps[i])) {
250  throw cet::exception("PointIdAlg") << "Patch buffering failed" << std::endl;
251  }
252  }
253 
254  return fNNet->Run(inps);
255 }
bool bufferPatch(size_t wire, float drift, std::vector< std::vector< float >> &patch) const
Definition: PointIdAlg.h:158
virtual std::vector< float > Run(std::vector< std::vector< float >> const &inp2d)=0
size_t fPatchSizeD
Definition: PointIdAlg.h:154
nnet::ModelInterface * fNNet
Definition: PointIdAlg.h:151
size_t fPatchSizeW
Definition: PointIdAlg.h:154
cet::coded_exception< error, detail::translate > exception
Definition: exception.h:33
QTextStream & endl(QTextStream &s) (spurious cross-reference: the listing above uses std::endl)
void nnet::PointIdAlg::resizePatch ( )
private

Definition at line 190 of file PointIdAlg.cxx.

191 {
193  for (auto& r : fWireDriftPatch)
194  r.resize(fPatchSizeD);
195 }
std::vector< std::vector< float > > fWireDriftPatch
Definition: PointIdAlg.h:153
size_t fPatchSizeD
Definition: PointIdAlg.h:154
size_t fPatchSizeW
Definition: PointIdAlg.h:154

Member Data Documentation

size_t nnet::PointIdAlg::fCurrentScaledDrift
mutable, private

Definition at line 156 of file PointIdAlg.h.

size_t nnet::PointIdAlg::fCurrentWireIdx
mutable, private

Definition at line 156 of file PointIdAlg.h.

nnet::ModelInterface* nnet::PointIdAlg::fNNet
private

Definition at line 151 of file PointIdAlg.h.

std::string nnet::PointIdAlg::fNNetModelFilePath
private

Definition at line 149 of file PointIdAlg.h.

std::vector<std::string> nnet::PointIdAlg::fNNetOutputs
private

Definition at line 150 of file PointIdAlg.h.

size_t nnet::PointIdAlg::fPatchSizeD
private

Definition at line 154 of file PointIdAlg.h.

size_t nnet::PointIdAlg::fPatchSizeW
private

Definition at line 154 of file PointIdAlg.h.

std::vector<std::vector<float> > nnet::PointIdAlg::fWireDriftPatch
mutable, private

Definition at line 153 of file PointIdAlg.h.


The documentation for this class was generated from the following files: