PointIdAlg.h
Go to the documentation of this file.
1 ////////////////////////////////////////////////////////////////////////////////////////////////////
2 // Class: PointIdAlg
3 // Authors: D.Stefan (Dorota.Stefan@ncbj.gov.pl), from DUNE, CERN/NCBJ, since May 2016
4 // R.Sulej (Robert.Sulej@cern.ch), from DUNE, FNAL/NCBJ, since May 2016
5 // P.Plonski, from DUNE, WUT, since May 2016
6 //
7 //
8 // Point Identification Algorithm
9 //
10 // Run CNN or MLP trained to classify a point in 2D projection. Various features can be
11 // recognized, depending on the net model/weights used.
12 //
13 ////////////////////////////////////////////////////////////////////////////////////////////////////
14 
15 #ifndef PointIdAlg_h
16 #define PointIdAlg_h
17 
18 // Framework includes
21 
22 // LArSoft includes
23 #include "canvas/Persistency/Common/FindManyP.h"
28 
32 
33 // ROOT & C++
34 #include <memory>
35 
36 namespace nnet
37 {
38  class ModelInterface;
39  class KerasModelInterface;
40  class TfModelInterface;
41  class PointIdAlg;
42  class TrainingDataAlg;
43 }
44 
/// Interface class for various classifier models. Implementations declared in this header:
/// KerasModelInterface (CNN, Keras with a simple C++ interface) and TfModelInterface
/// (TensorFlow graph, wrapped in tf::Graph).
49 {
50 public:
51  virtual ~ModelInterface(void) { }
52 
53  virtual std::vector<float> Run(std::vector< std::vector<float> > const & inp2d) = 0;
54  virtual std::vector< std::vector<float> > Run(std::vector< std::vector< std::vector<float> > > const & inps, int samples = -1);
55 
56 protected:
57  ModelInterface(void) { }
58 
59  std::string findFile(const char* fileName) const;
60 };
61 // ------------------------------------------------------
62 
64 {
65 public:
66  KerasModelInterface(const char* modelFileName);
67 
68  std::vector<float> Run(std::vector< std::vector<float> > const & inp2d) override;
69 
70 private:
71  keras::KerasModel m; // network model
72 };
73 // ------------------------------------------------------
74 
76 {
77 public:
78  TfModelInterface(const char* modelFileName);
79 
80  std::vector< std::vector<float> > Run(std::vector< std::vector< std::vector<float> > > const & inps, int samples = -1) override;
81  std::vector<float> Run(std::vector< std::vector<float> > const & inp2d) override;
82 
83 private:
84  std::unique_ptr<tf::Graph> g; // network graph
85 };
86 // ------------------------------------------------------
87 
89 {
90 public:
91 
93  {
94  using Name = fhicl::Name;
96 
97  fhicl::Atom<std::string> NNetModelFile {
98  Name("NNetModelFile"), Comment("Neural net model to apply.")
99  };
101  Name("NNetOutputs"), Comment("Labels of the network outputs.")
102  };
103 
105  Name("PatchSizeW"), Comment("How many wires in patch.")
106  };
107 
109  Name("PatchSizeD"), Comment("How many downsampled ADC entries in patch")
110  };
111  };
112 
114  PointIdAlg(fhicl::Table<Config>(pset, {})())
115  {}
116 
117  PointIdAlg(const Config& config);
118 
119  ~PointIdAlg(void) override;
120 
121  /// network output labels
122  std::vector< std::string > const & outputLabels(void) const { return fNNetOutputs; }
123 
124  /// calculate single-value prediction (2-class probability) for [wire, drift] point
125  float predictIdValue(unsigned int wire, float drift, size_t outIdx = 0) const;
126 
127  /// calculate multi-class probabilities for [wire, drift] point
128  std::vector<float> predictIdVector(unsigned int wire, float drift) const;
129 
130  std::vector< std::vector<float> > predictIdVectors(std::vector< std::pair<unsigned int, float> > points) const;
131 
132  static std::vector<float> flattenData2D(std::vector< std::vector<float> > const & patch);
133 
134  std::vector< std::vector<float> > const & patchData2D(void) const { return fWireDriftPatch; }
135  std::vector<float> patchData1D(void) const { return flattenData2D(fWireDriftPatch); } // flat vector made of the patch data, wire after wire
136 
137  bool isInsideFiducialRegion(unsigned int wire, float drift) const;
138 
139  /// test if wire/drift coordinates point to the current patch (so maybe the cnn output
140  /// does not need to be recalculated)
141  bool isCurrentPatch(unsigned int wire, float drift) const;
142 
143  /// test if two wire/drift coordinates point to the same patch
144  bool isSamePatch(unsigned int wire1, float drift1, unsigned int wire2, float drift2) const;
145 
146 private:
148  std::vector< std::string > fNNetOutputs;
150 
151  mutable std::vector< std::vector<float> > fWireDriftPatch; // patch data around the identified point
152  size_t fPatchSizeW, fPatchSizeD;
153 
154  mutable size_t fCurrentWireIdx, fCurrentScaledDrift;
155  bool bufferPatch(size_t wire, float drift, std::vector< std::vector<float> > & patch) const
156  {
157  if (fDownscaleFullView)
158  {
159  size_t sd = (size_t)(drift / fDriftWindow);
160  if ((fCurrentWireIdx == wire) && (fCurrentScaledDrift == sd))
161  return true; // still within the current position
162 
163  fCurrentWireIdx = wire; fCurrentScaledDrift = sd;
164 
165  return patchFromDownsampledView(wire, drift, fPatchSizeW, fPatchSizeD, patch);
166  }
167  else
168  {
169  if ((fCurrentWireIdx == wire) && (fCurrentScaledDrift == drift))
170  return true; // still within the current position
171 
172  fCurrentWireIdx = wire; fCurrentScaledDrift = drift;
173 
174  return patchFromOriginalView(wire, drift, fPatchSizeW, fPatchSizeD, patch);
175  }
176  }
177  bool bufferPatch(size_t wire, float drift) const { return bufferPatch(wire, drift, fWireDriftPatch); }
178  void resizePatch(void);
179 
180  void deleteNNet(void) { if (fNNet) delete fNNet; fNNet = 0; }
181 };
182 // ------------------------------------------------------
183 // ------------------------------------------------------
184 // ------------------------------------------------------
185 
187 {
188 public:
189 
/// Bit-field layout of the label word (presumably the entries of the PDG map returned by
/// wirePdg(); vertex flags are included only with SaveVtxFlags):
/// bits 0-11 PDG code, bits 12-15 track type, bits 16-31 vertex flags.
enum EMask
{
  kNone     = 0,
  kPdgMask  = 0x00000FFF, // pdg code mask
  kTypeMask = 0x0000F000, // track type mask
  kVtxMask  = 0xFFFF0000  // vertex flags
};

/// Track-type flags, one bit each inside kTypeMask.
enum ETrkType
{
  kDelta  = 1 << 12, // delta electron   (0x1000)
  kMichel = 1 << 13, // Michel electron  (0x2000)
  kPriEl  = 1 << 14, // primary electron (0x4000)
  kPriMu  = 1 << 15  // primary muon     (0x8000)
};

/// Vertex flags, one bit each inside kVtxMask.
enum EVtxId
{
  // nu interaction type
  kNuNC = 1 << 16, kNuCC = 1 << 17, kNuPri = 1 << 18,
  // nu flavor
  kNuE = 1 << 20, kNuMu = 1 << 21, kNuTau = 1 << 22,
  kHadr        = 1 << 24, // hadronic inelastic scattering
  kPi0         = 1 << 25, // pi0 produced in this vertex
  kDecay       = 1 << 26, // point of particle decay
  kConv        = 1 << 27, // gamma conversion
  kElectronEnd = 1 << 28, // clear end of an electron
  kElastic     = 1 << 29, // Elastic scattering
  kInelastic   = 1 << 30  // Inelastic scattering
};
218 
220  {
221  using Name = fhicl::Name;
223 
225  Name("WireLabel"),
226  Comment("Tag of recob::Wire.")
227  };
228 
230  Name("HitLabel"),
231  Comment("Tag of recob::Hit.")
232  };
233 
235  Name("TrackLabel"),
236  Comment("Tag of recob::Track.")
237  };
238 
240  Name("SimulationLabel"),
241  Comment("Tag of simulation producer.")
242  };
243 
244  fhicl::Atom< bool > SaveVtxFlags {
245  Name("SaveVtxFlags"),
246  Comment("Include (or not) vertex info in PDG map.")
247  };
248 
250  Name("AdcDelayTicks"),
251  Comment("ADC pulse peak delay in ticks (non-zero for not deconvoluted waveforms).")
252  };
253  };
254 
256  TrainingDataAlg(fhicl::Table<Config>(pset, {})())
257  {}
258 
259  TrainingDataAlg(const Config& config);
260 
261  ~TrainingDataAlg(void) override;
262 
263  void reconfigure(const Config& config);
264 
265  bool saveSimInfo() const { return fSaveSimInfo; }
266 
267  bool setEventData(const art::Event& event, // collect & downscale ADC's, charge deposits, pdg labels
268  unsigned int plane, unsigned int tpc, unsigned int cryo);
269 
270  bool setDataEventData(const art::Event& event, // collect & downscale ADC's, charge deposits, pdg labels
271  unsigned int plane, unsigned int tpc, unsigned int cryo);
272 
273 
274  bool findCrop(float max_e_cut, unsigned int & w0, unsigned int & w1, unsigned int & d0, unsigned int & d1) const;
275 
276  double getEdepTot(void) const { return fEdepTot; } // [GeV]
277  std::vector<float> const & wireEdep(size_t widx) const { return fWireDriftEdep[widx]; }
278  std::vector<int> const & wirePdg(size_t widx) const { return fWireDriftPdg[widx]; }
279 
280 protected:
281 
282  void resizeView(size_t wires, size_t drifts) override;
283 
284 private:
285 
286  struct WireDrift // used to find MCParticle start/end 2D projections
287  {
288  size_t Wire;
289  int Drift;
290  int TPC;
291  int Cryo;
292  };
293 
294  WireDrift getProjection(const TLorentzVector& tvec, unsigned int plane) const;
295 
296  bool setWireEdepsAndLabels(
297  std::vector<float> const & edeps,
298  std::vector<int> const & pdgs,
299  size_t wireIdx);
300 
301  void collectVtxFlags(
302  std::unordered_map< size_t, std::unordered_map< int, int > > & wireToDriftToVtxFlags,
303  const std::unordered_map< int, const simb::MCParticle* > & particleMap,
304  unsigned int plane) const;
305 
306  static float particleRange2(const simb::MCParticle & particle)
307  {
308  float dx = particle.EndX() - particle.Vx();
309  float dy = particle.EndY() - particle.Vy();
310  float dz = particle.EndZ() - particle.Vz();
311  return dx*dx + dy*dy + dz*dz;
312  }
313  bool isElectronEnd(
314  const simb::MCParticle & particle,
315  const std::unordered_map< int, const simb::MCParticle* > & particleMap) const;
316 
317  bool isMuonDecaying(
318  const simb::MCParticle & particle,
319  const std::unordered_map< int, const simb::MCParticle* > & particleMap) const;
320 
321  double fEdepTot; // [GeV]
322  std::vector< std::vector<float> > fWireDriftEdep;
323  std::vector< std::vector<int> > fWireDriftPdg;
324 
331 
332  unsigned int fAdcDelay;
333 
334  std::vector<size_t> fEventsPerBin;
335 };
336 // ------------------------------------------------------
337 // ------------------------------------------------------
338 // ------------------------------------------------------
339 
340 #endif
bool bufferPatch(size_t wire, float drift, std::vector< std::vector< float > > &patch) const
Definition: PointIdAlg.h:155
std::vector< float > const & wireEdep(size_t widx) const
Definition: PointIdAlg.h:277
std::vector< float > patchData1D(void) const
Definition: PointIdAlg.h:135
double EndZ() const
Definition: MCParticle.h:227
bool saveSimInfo() const
Definition: PointIdAlg.h:265
unsigned int fAdcDelay
Definition: PointIdAlg.h:332
std::vector< std::vector< int > > fWireDriftPdg
Definition: PointIdAlg.h:323
std::vector< std::vector< float > > const & patchData2D(void) const
Definition: PointIdAlg.h:134
void deleteNNet(void)
Definition: PointIdAlg.h:180
std::string string
Definition: nybbler.cc:12
bool bufferPatch(size_t wire, float drift) const
Definition: PointIdAlg.h:177
art::InputTag fTrackModuleLabel
Definition: PointIdAlg.h:327
Declaration of signal hit object.
std::vector< std::vector< float > > fWireDriftPatch
Definition: PointIdAlg.h:151
struct vector vector
virtual std::vector< float > Run(std::vector< std::vector< float > > const &inp2d)=0
std::vector< size_t > fEventsPerBin
Definition: PointIdAlg.h:334
Particle class.
double EndY() const
Definition: MCParticle.h:226
no compression
Definition: RawTypes.h:9
std::vector< std::string > fNNetOutputs
Definition: PointIdAlg.h:148
nnet::ModelInterface * fNNet
Definition: PointIdAlg.h:149
art::InputTag fWireProducerLabel
Definition: PointIdAlg.h:325
std::vector< std::string > const & outputLabels(void) const
network output labels
Definition: PointIdAlg.h:122
static Config * config
Definition: config.cpp:1054
std::vector< std::vector< float > > fWireDriftEdep
Definition: PointIdAlg.h:322
keras::KerasModel m
Definition: PointIdAlg.h:71
Provides recob::Track data product.
double getEdepTot(void) const
Definition: PointIdAlg.h:276
AdcCodeMitigator::Name Name
double Vx(const int i=0) const
Definition: MCParticle.h:220
#define Comment
virtual ~ModelInterface(void)
Definition: PointIdAlg.h:51
std::string findFile(const char *fileName) const
Definition: PointIdAlg.cxx:52
static float particleRange2(const simb::MCParticle &particle)
Definition: PointIdAlg.h:306
double Vz(const int i=0) const
Definition: MCParticle.h:222
size_t fPatchSizeW
Definition: PointIdAlg.h:152
art::InputTag fSimulationProducerLabel
Definition: PointIdAlg.h:328
std::string fNNetModelFilePath
Definition: PointIdAlg.h:147
std::unique_ptr< tf::Graph > g
Definition: PointIdAlg.h:84
std::vector< int > const & wirePdg(size_t widx) const
Definition: PointIdAlg.h:278
double EndX() const
Definition: MCParticle.h:225
art::InputTag fHitProducerLabel
Definition: PointIdAlg.h:326
double Vy(const int i=0) const
Definition: MCParticle.h:221
Event finding and building.
TrainingDataAlg(const fhicl::ParameterSet &pset)
Definition: PointIdAlg.h:255
size_t fCurrentWireIdx
Definition: PointIdAlg.h:154
PointIdAlg(const fhicl::ParameterSet &pset)
Definition: PointIdAlg.h:113