PointIdAlg.h
Go to the documentation of this file.
1 ////////////////////////////////////////////////////////////////////////////////////////////////////
2 // Class: PointIdAlg
3 // Authors: D.Stefan (Dorota.Stefan@ncbj.gov.pl), from DUNE, CERN/NCBJ, since May 2016
4 // R.Sulej (Robert.Sulej@cern.ch), from DUNE, FNAL/NCBJ, since May 2016
5 // P.Plonski, from DUNE, WUT, since May 2016
6 //
7 //
8 // Point Identification Algorithm
9 //
10 // Run CNN or MLP trained to classify a point in 2D projection. Various features can be
11 // recognized, depending on the net model/weights used.
12 //
13 ////////////////////////////////////////////////////////////////////////////////////////////////////
14 
15 #ifndef PointIdAlg_h
16 #define PointIdAlg_h
17 
18 // Framework includes
20 #include "canvas/Persistency/Common/FindManyP.h"
22 
23 // LArSoft includes
31 namespace detinfo {
32  class DetectorClocksData;
33  class DetectorPropertiesData;
34 }
35 
36 // ROOT & C++
37 #include <memory>
38 
39 namespace nnet {
40  class ModelInterface;
41  class KerasModelInterface;
42  class TfModelInterface;
43  class PointIdAlg;
44  class TrainingDataAlg;
45 }
46 
/// Interface class for the various classifier models. MLP (NetMaker) and CNN (Keras, with a
/// simple C++ interface) models are supported; TensorFlow graphs (Protobuf) are handled by
/// TfModelInterface.
51 public:
52  virtual ~ModelInterface() {}
53 
54  virtual std::vector<float> Run(std::vector<std::vector<float>> const& inp2d) = 0;
55  virtual std::vector<std::vector<float>> Run(
56  std::vector<std::vector<std::vector<float>>> const& inps,
57  int samples = -1);
58 
59 protected:
60  std::string findFile(const char* fileName) const;
61 };
62 // ------------------------------------------------------
63 
65 public:
66  KerasModelInterface(const char* modelFileName);
67 
68  std::vector<float> Run(std::vector<std::vector<float>> const& inp2d) override;
69 
70 private:
71  keras::KerasModel m; // network model
72 };
73 // ------------------------------------------------------
74 
76 public:
77  TfModelInterface(const char* modelFileName);
78 
79  std::vector<std::vector<float>> Run(std::vector<std::vector<std::vector<float>>> const& inps,
80  int samples = -1) override;
81  std::vector<float> Run(std::vector<std::vector<float>> const& inp2d) override;
82 
83 private:
84  std::unique_ptr<tf::Graph> g; // network graph
85 };
86 // ------------------------------------------------------
87 
89 public:
91  using Name = fhicl::Name;
93 
94  fhicl::Atom<std::string> NNetModelFile{Name("NNetModelFile"),
95  Comment("Neural net model to apply.")};
96  fhicl::Sequence<std::string> NNetOutputs{Name("NNetOutputs"),
97  Comment("Labels of the network outputs.")};
98  fhicl::Atom<unsigned int> PatchSizeW{Name("PatchSizeW"), Comment("How many wires in patch.")};
99 
100  fhicl::Atom<unsigned int> PatchSizeD{Name("PatchSizeD"),
101  Comment("How many downsampled ADC entries in patch")};
102  };
103 
104  PointIdAlg(const fhicl::ParameterSet& pset) : PointIdAlg(fhicl::Table<Config>(pset, {})()) {}
105 
106  PointIdAlg(const Config& config);
107 
108  ~PointIdAlg() override;
109 
110  /// network output labels
111  std::vector<std::string> const&
112  outputLabels() const
113  {
114  return fNNetOutputs;
115  }
116 
117  /// calculate single-value prediction (2-class probability) for [wire, drift] point
118  float predictIdValue(unsigned int wire, float drift, size_t outIdx = 0) const;
119 
120  /// calculate multi-class probabilities for [wire, drift] point
121  std::vector<float> predictIdVector(unsigned int wire, float drift) const;
122 
123  std::vector<std::vector<float>> predictIdVectors(
124  std::vector<std::pair<unsigned int, float>> points) const;
125 
126  static std::vector<float> flattenData2D(std::vector<std::vector<float>> const& patch);
127 
128  std::vector<std::vector<float>> const&
129  patchData2D() const
130  {
131  return fWireDriftPatch;
132  }
133  std::vector<float>
134  patchData1D() const
135  {
136  return flattenData2D(fWireDriftPatch);
137  } // flat vector made of the patch data, wire after wire
138 
139  bool isInsideFiducialRegion(unsigned int wire, float drift) const;
140 
141  /// test if wire/drift coordinates point to the current patch (so maybe the cnn output
142  /// does not need to be recalculated)
143  bool isCurrentPatch(unsigned int wire, float drift) const;
144 
145  /// test if two wire/drift coordinates point to the same patch
146  bool isSamePatch(unsigned int wire1, float drift1, unsigned int wire2, float drift2) const;
147 
148 private:
150  std::vector<std::string> fNNetOutputs;
152 
153  mutable std::vector<std::vector<float>> fWireDriftPatch; // patch data around the identified point
154  size_t fPatchSizeW, fPatchSizeD;
155 
156  mutable size_t fCurrentWireIdx, fCurrentScaledDrift;
157  bool
158  bufferPatch(size_t wire, float drift, std::vector<std::vector<float>>& patch) const
159  {
160  if (fDownscaleFullView) {
161  size_t sd = (size_t)(drift / fDriftWindow);
162  if ((fCurrentWireIdx == wire) && (fCurrentScaledDrift == sd))
163  return true; // still within the current position
164 
165  fCurrentWireIdx = wire;
166  fCurrentScaledDrift = sd;
167 
168  return patchFromDownsampledView(wire, drift, fPatchSizeW, fPatchSizeD, patch);
169  }
170  else {
171  if ((fCurrentWireIdx == wire) && (fCurrentScaledDrift == drift))
172  return true; // still within the current position
173 
174  fCurrentWireIdx = wire;
175  fCurrentScaledDrift = drift;
176 
177  return patchFromOriginalView(wire, drift, fPatchSizeW, fPatchSizeD, patch);
178  }
179  }
180  bool
181  bufferPatch(size_t wire, float drift) const
182  {
183  return bufferPatch(wire, drift, fWireDriftPatch);
184  }
185  void resizePatch();
186 
187  void
189  {
190  if (fNNet) delete fNNet;
191  fNNet = 0;
192  }
193 };
194 // ------------------------------------------------------
195 // ------------------------------------------------------
196 // ------------------------------------------------------
197 
199 public:
200  enum EMask {
201  kNone = 0,
202  kPdgMask = 0x00000FFF, // pdg code mask
203  kTypeMask = 0x0000F000, // track type mask
204  kVtxMask = 0xFFFF0000 // vertex flags
205  };
206 
207  enum ETrkType {
208  kDelta = 0x1000, // delta electron
209  kMichel = 0x2000, // Michel electron
210  kPriEl = 0x4000, // primary electron
211  kPriMu = 0x8000 // primary muon
212  };
213 
214  enum EVtxId {
215  kNuNC = 0x0010000,
216  kNuCC = 0x0020000,
217  kNuPri = 0x0040000, // nu interaction type
218  kNuE = 0x0100000,
219  kNuMu = 0x0200000,
220  kNuTau = 0x0400000, // nu flavor
221  kHadr = 0x1000000, // hadronic inelastic scattering
222  kPi0 = 0x2000000, // pi0 produced in this vertex
223  kDecay = 0x4000000, // point of particle decay
224  kConv = 0x8000000, // gamma conversion
225  kElectronEnd = 0x10000000, // clear end of an electron
226  kElastic = 0x20000000, // Elastic scattering
227  kInelastic = 0x40000000 // Inelastic scattering
228  };
229 
231  using Name = fhicl::Name;
233 
234  fhicl::Atom<art::InputTag> WireLabel{Name("WireLabel"), Comment("Tag of recob::Wire.")};
235 
236  fhicl::Atom<art::InputTag> HitLabel{Name("HitLabel"), Comment("Tag of recob::Hit.")};
237 
238  fhicl::Atom<art::InputTag> TrackLabel{Name("TrackLabel"), Comment("Tag of recob::Track.")};
239 
240  fhicl::Atom<art::InputTag> SimulationLabel{Name("SimulationLabel"),
241  Comment("Tag of simulation producer.")};
242 
243  fhicl::Atom<art::InputTag> SimChannelLabel{Name("SimChannelLabel"),
244  Comment("Tag of sim::SimChannel producer.")};
245 
246  fhicl::Atom<bool> SaveVtxFlags{Name("SaveVtxFlags"),
247  Comment("Include (or not) vertex info in PDG map.")};
248 
250  Name("AdcDelayTicks"),
251  Comment("ADC pulse peak delay in ticks (non-zero for not deconvoluted waveforms).")};
252  };
253 
255  : TrainingDataAlg(fhicl::Table<Config>(pset, {})())
256  {}
257 
258  TrainingDataAlg(const Config& config);
259 
260  ~TrainingDataAlg() override;
261 
262  void reconfigure(const Config& config);
263 
264  bool
265  saveSimInfo() const
266  {
267  return fSaveSimInfo;
268  }
269 
270  bool setEventData(
271  const art::Event& event, // collect & downscale ADC's, charge deposits, pdg labels
272  detinfo::DetectorClocksData const& clockData,
273  detinfo::DetectorPropertiesData const& detProp,
274  unsigned int plane,
275  unsigned int tpc,
276  unsigned int cryo);
277 
278  bool setDataEventData(
279  const art::Event& event, // collect & downscale ADC's, charge deposits, pdg labels
280  detinfo::DetectorClocksData const& clockData,
281  detinfo::DetectorPropertiesData const& detProp,
282  unsigned int plane,
283  unsigned int tpc,
284  unsigned int cryo);
285 
286  bool findCrop(float max_e_cut,
287  unsigned int& w0,
288  unsigned int& w1,
289  unsigned int& d0,
290  unsigned int& d1) const;
291 
292  double
293  getEdepTot() const
294  {
295  return fEdepTot;
296  } // [GeV]
297  std::vector<float> const&
298  wireEdep(size_t widx) const
299  {
300  return fWireDriftEdep[widx];
301  }
302  std::vector<int> const&
303  wirePdg(size_t widx) const
304  {
305  return fWireDriftPdg[widx];
306  }
307 
308 protected:
309  img::DataProviderAlgView resizeView(detinfo::DetectorClocksData const& clock_data,
310  detinfo::DetectorPropertiesData const& det_prop,
311  size_t wires,
312  size_t drifts) override;
313 
314 private:
315  struct WireDrift // used to find MCParticle start/end 2D projections
316  {
317  size_t Wire;
318  int Drift;
319  unsigned int TPC;
320  unsigned int Cryo;
321  };
322 
323  WireDrift getProjection(detinfo::DetectorClocksData const& clockData,
324  detinfo::DetectorPropertiesData const& detProp,
325  const TLorentzVector& tvec,
326  unsigned int plane) const;
327 
328  bool setWireEdepsAndLabels(std::vector<float> const& edeps,
329  std::vector<int> const& pdgs,
330  size_t wireIdx);
331 
332  void collectVtxFlags(
333  std::unordered_map<size_t, std::unordered_map<int, int>>& wireToDriftToVtxFlags,
334  detinfo::DetectorClocksData const& clockData,
335  detinfo::DetectorPropertiesData const& detProp,
336  const std::unordered_map<int, const simb::MCParticle*>& particleMap,
337  unsigned int plane) const;
338 
339  static float
341  {
342  float dx = particle.EndX() - particle.Vx();
343  float dy = particle.EndY() - particle.Vy();
344  float dz = particle.EndZ() - particle.Vz();
345  return dx * dx + dy * dy + dz * dz;
346  }
347  bool isElectronEnd(const simb::MCParticle& particle,
348  const std::unordered_map<int, const simb::MCParticle*>& particleMap) const;
349 
350  bool isMuonDecaying(const simb::MCParticle& particle,
351  const std::unordered_map<int, const simb::MCParticle*>& particleMap) const;
352 
353  double fEdepTot; // [GeV]
354  std::vector<std::vector<float>> fWireDriftEdep;
355  std::vector<std::vector<int>> fWireDriftPdg;
356 
364 
365  unsigned int fAdcDelay;
366 
367  std::vector<size_t> fEventsPerBin;
368 };
369 // ------------------------------------------------------
370 // ------------------------------------------------------
371 // ------------------------------------------------------
372 
373 #endif
std::vector< std::string > const & outputLabels() const
network output labels
Definition: PointIdAlg.h:112
std::vector< float > const & wireEdep(size_t widx) const
Definition: PointIdAlg.h:298
double EndZ() const
Definition: MCParticle.h:228
std::vector< std::vector< float > > fWireDriftPatch
Definition: PointIdAlg.h:153
bool saveSimInfo() const
Definition: PointIdAlg.h:265
unsigned int fAdcDelay
Definition: PointIdAlg.h:365
std::string string
Definition: nybbler.cc:12
bool bufferPatch(size_t wire, float drift) const
Definition: PointIdAlg.h:181
art::InputTag fTrackModuleLabel
Definition: PointIdAlg.h:359
bool bufferPatch(size_t wire, float drift, std::vector< std::vector< float >> &patch) const
Definition: PointIdAlg.h:158
struct vector vector
ChannelGroupService::Name Name
Particle class.
double EndY() const
Definition: MCParticle.h:227
double getEdepTot() const
Definition: PointIdAlg.h:293
std::vector< std::string > fNNetOutputs
Definition: PointIdAlg.h:150
std::vector< float > patchData1D() const
Definition: PointIdAlg.h:134
virtual ~ModelInterface()
Definition: PointIdAlg.h:52
static FILE * findFile(const char *fileName)
Definition: config.cpp:1144
art::InputTag fWireProducerLabel
Definition: PointIdAlg.h:357
std::vector< std::vector< float > > const & patchData2D() const
Definition: PointIdAlg.h:129
fileName
Definition: dumpTree.py:9
virtual void reconfigure(fhicl::ParameterSet const &pset)
static Config * config
Definition: config.cpp:1054
std::vector< size_t > fEventsPerBin
Definition: PointIdAlg.h:367
art::InputTag fSimChannelProducerLabel
Definition: PointIdAlg.h:361
General LArSoft Utilities.
keras::KerasModel m
Definition: PointIdAlg.h:71
no compression
Definition: RawTypes.h:16
std::vector< std::vector< int > > fWireDriftPdg
Definition: PointIdAlg.h:355
double Vx(const int i=0) const
Definition: MCParticle.h:221
Declaration of signal hit object.
#define Comment
std::vector< std::vector< float > > fWireDriftEdep
Definition: PointIdAlg.h:354
Contains all timing reference information for the detector.
static float particleRange2(const simb::MCParticle &particle)
Definition: PointIdAlg.h:340
Provides recob::Track data product.
double Vz(const int i=0) const
Definition: MCParticle.h:223
nnet::ModelInterface * fNNet
Definition: PointIdAlg.h:151
std::unique_ptr< tf::Graph > g
Definition: PointIdAlg.h:84
size_t fPatchSizeW
Definition: PointIdAlg.h:154
art::InputTag fSimulationProducerLabel
Definition: PointIdAlg.h:360
std::string fNNetModelFilePath
Definition: PointIdAlg.h:149
std::vector< int > const & wirePdg(size_t widx) const
Definition: PointIdAlg.h:303
double EndX() const
Definition: MCParticle.h:226
art::InputTag fHitProducerLabel
Definition: PointIdAlg.h:358
double Vy(const int i=0) const
Definition: MCParticle.h:222
Event finding and building.
TrainingDataAlg(const fhicl::ParameterSet &pset)
Definition: PointIdAlg.h:254
size_t fCurrentWireIdx
Definition: PointIdAlg.h:156
PointIdAlg(const fhicl::ParameterSet &pset)
Definition: PointIdAlg.h:104