IPointIdAlg.h
Go to the documentation of this file.
1 ////////////////////////////////////////////////////////////////////////////////////////////////////
2 // Class: IPointIdAlg (interface for tool version of PointIdAlg)
3 // Authors: D.Stefan (Dorota.Stefan@ncbj.gov.pl), from DUNE, CERN/NCBJ, since May 2016
4 // R.Sulej (Robert.Sulej@cern.ch), from DUNE, FNAL/NCBJ, since May 2016
5 // P.Plonski, from DUNE, WUT, since May 2016
6 // M.Wang, from DUNE, FNAL, 2020: tool version
7 //
8 //
9 // Point Identification Algorithm
10 //
11 // Runs a CNN or MLP trained to classify a point in a 2D projection. Various features can be
12 // recognized, depending on the network model/weights used.
13 //
14 ////////////////////////////////////////////////////////////////////////////////////////////////////
15 
16 #ifndef IPointIdAlg_H
17 #define IPointIdAlg_H
18 
23 
24 namespace PointIdAlgTools {
25  class IPointIdAlg : virtual public img::DataProviderAlg {
26  public:
28  using Name = fhicl::Name;
30 
32  Comment("Neural net model to apply.")};
34  Comment("Labels of the network outputs.")};
36  Name("NNetOutputPattern"),
37  Comment("Pattern to use when searching for network outputs.")};
38  fhicl::Atom<unsigned int> PatchSizeW{Name("PatchSizeW"), Comment("How many wires in patch.")};
40  Comment("How many downsampled ADC entries in patch")};
42  Comment("PointID algorithm tool type")};
44  Name("TritonModelName"),
45  Comment("Model directory name in repository of Nvidia Triton inference server"),
46  "mycnn"};
48  Comment("URL of Nvidia Triton inference server"),
49  "localhost:8001"};
51  Name("TritonModelVersion"),
52  Comment("Version number of Nvidia Triton inference server model"),
53  ""};
55  Name("TritonVerbose"),
56  Comment("Verbosity switch for Nvidia Triton inference server client"),
57  false};
59  Name("TritonAllowedTries"),
60  Comment("Number of allowed attempts for Nvidia Triton inference server client"),
61  1};
62  };
63  virtual ~IPointIdAlg() noexcept = default;
64 
65  // Define standard art tool interface
66  //virtual void configure(const fhicl::ParameterSet& pset) = 0;
67 
68  virtual std::vector<float> Run(std::vector<std::vector<float>> const& inp2d) const = 0;
69  virtual std::vector<std::vector<float>> Run(
70  std::vector<std::vector<std::vector<float>>> const& inps,
71  int samples = -1) const = 0;
72 
73  // calculate single-value prediction (2-class probability) for [wire, drift] point
74  float
75  predictIdValue(unsigned int wire, float drift, size_t outIdx = 0)
76  {
77  float result = 0.;
78 
79  if (!bufferPatch(wire, drift)) {
80  mf::LogError("PointIdAlg") << "Patch buffering failed.";
81  return result;
82  }
83 
84  auto out = Run(fWireDriftPatch);
85  if (!out.empty()) { result = out[outIdx]; }
86  else {
87  mf::LogError("PointIdAlg") << "Problem with applying model to input.";
88  }
89 
90  return result;
91  }
92 
93  // Calculate multi-class probabilities for [wire, drift] point
94  std::vector<float>
95  predictIdVector(unsigned int wire, float drift)
96  {
97  std::vector<float> result;
98 
99  if (!bufferPatch(wire, drift)) {
100  mf::LogError("PointIdAlg") << "Patch buffering failed.";
101  return result;
102  }
103 
104  result = Run(fWireDriftPatch);
105  if (result.empty()) { mf::LogError("PointIdAlg") << "Problem with applying model to input."; }
106 
107  return result;
108  }
109 
110  // Calculate multi-class probabilities for a vector of [wire, drift] points
111  std::vector<std::vector<float>>
112  predictIdVectors(const std::vector<std::pair<unsigned int, float>>& points)
113  {
114  if (points.empty()) { return std::vector<std::vector<float>>(); }
115 
116  std::vector<std::vector<std::vector<float>>> inps(
117  points.size(),
118  std::vector<std::vector<float>>(fPatchSizeW, std::vector<float>(fPatchSizeD)));
119  for (size_t i = 0; i < points.size(); ++i) {
120  unsigned int wire = points[i].first;
121  float drift = points[i].second;
122  if (!bufferPatch(wire, drift, inps[i])) {
123  throw cet::exception("PointIdAlg") << "Patch buffering failed" << std::endl;
124  }
125  }
126 
127  return Run(inps);
128  }
129 
130  std::vector<std::string> const&
131  outputLabels(void) const
132  {
133  return fNNetOutputs;
134  }
135  bool
136  isInsideFiducialRegion(unsigned int wire, float drift) const
137  {
138  size_t marginW = fPatchSizeW / 8; // fPatchSizeX/2 will make patch always completely filled
139  size_t marginD = fPatchSizeD / 8;
140 
141  size_t scaledDrift = (size_t)(drift / fDriftWindow);
142  if ((wire >= marginW) && (wire < fAlgView.fNWires - marginW) && (scaledDrift >= marginD) &&
143  (scaledDrift < fAlgView.fNScaledDrifts - marginD)) {
144  return true;
145  }
146  else {
147  return false;
148  }
149  }
150 
151  protected:
152  std::vector<std::string> fNNetOutputs;
154  std::vector<std::vector<float>> fWireDriftPatch; // patch data around the identified point
156 
157  bool
158  bufferPatch(size_t wire, float drift, std::vector<std::vector<float>>& patch)
159  {
160  if (fDownscaleFullView) {
161  size_t sd = (size_t)(drift / fDriftWindow);
162  if ((fCurrentWireIdx == wire) && (fCurrentScaledDrift == sd))
163  return true; // still within the current position
164 
165  fCurrentWireIdx = wire;
166  fCurrentScaledDrift = sd;
167 
168  return patchFromDownsampledView(wire, drift, fPatchSizeW, fPatchSizeD, patch);
169  }
170  else {
171  if ((fCurrentWireIdx == wire) && (fCurrentScaledDrift == drift))
172  return true; // still within the current position
173 
174  fCurrentWireIdx = wire;
175  fCurrentScaledDrift = drift;
176 
177  return patchFromOriginalView(wire, drift, fPatchSizeW, fPatchSizeD, patch);
178  }
179  }
180 
181  bool
182  bufferPatch(size_t wire, float drift)
183  {
184  return bufferPatch(wire, drift, fWireDriftPatch);
185  }
186 
187  void
189  {
190  fWireDriftPatch.resize(fPatchSizeW);
191  for (auto& r : fWireDriftPatch)
192  r.resize(fPatchSizeD);
193  }
194  };
195 }
196 
197 #endif
fhicl::Atom< std::string > TritonURL
Definition: IPointIdAlg.h:47
std::vector< std::vector< float > > fWireDriftPatch
Definition: IPointIdAlg.h:154
static QCString result
fhicl::OptionalAtom< std::string > NNetModelFile
Definition: IPointIdAlg.h:31
struct vector vector
ChannelGroupService::Name Name
fhicl::Atom< std::string > TritonModelVersion
Definition: IPointIdAlg.h:50
bool bufferPatch(size_t wire, float drift)
Definition: IPointIdAlg.h:182
MaybeLogger_< ELseverityLevel::ELsev_error, false > LogError
std::vector< std::string > fNNetOutputs
Definition: IPointIdAlg.h:152
fhicl::OptionalAtom< std::string > ToolType
Definition: IPointIdAlg.h:41
float predictIdValue(unsigned int wire, float drift, size_t outIdx=0)
Definition: IPointIdAlg.h:75
fhicl::Atom< std::string > TritonModelName
Definition: IPointIdAlg.h:43
fhicl::Atom< unsigned int > PatchSizeW
Definition: IPointIdAlg.h:38
std::vector< float > predictIdVector(unsigned int wire, float drift)
Definition: IPointIdAlg.h:95
fhicl::Atom< unsigned int > PatchSizeD
Definition: IPointIdAlg.h:39
DataProviderAlgView fAlgView
bool bufferPatch(size_t wire, float drift, std::vector< std::vector< float >> &patch)
Definition: IPointIdAlg.h:158
fhicl::Atom< bool > TritonVerbose
Definition: IPointIdAlg.h:54
bool patchFromOriginalView(size_t wire, float drift, size_t size_w, size_t size_d, std::vector< std::vector< float >> &patch) const
fhicl::Sequence< std::string > NNetOutputs
Definition: IPointIdAlg.h:33
std::vector< std::vector< float > > predictIdVectors(const std::vector< std::pair< unsigned int, float >> &points)
Definition: IPointIdAlg.h:112
#define Comment
virtual ~IPointIdAlg() noexcept=default
bool isInsideFiducialRegion(unsigned int wire, float drift) const
Definition: IPointIdAlg.h:136
std::vector< std::string > const & outputLabels(void) const
Definition: IPointIdAlg.h:131
fhicl::OptionalSequence< std::string > NNetOutputPattern
Definition: IPointIdAlg.h:35
fhicl::Atom< unsigned > TritonAllowedTries
Definition: IPointIdAlg.h:58
bool patchFromDownsampledView(size_t wire, float drift, size_t size_w, size_t size_d, std::vector< std::vector< float >> &patch) const
virtual std::vector< float > Run(std::vector< std::vector< float >> const &inp2d) const =0
cet::coded_exception< error, detail::translate > exception
Definition: exception.h:33
QTextStream & endl(QTextStream &s)