IWaveformRecog.h
Go to the documentation of this file.
1 #ifndef IWaveformRecog_H
2 #define IWaveformRecog_H
3 
#include <sys/stat.h>

#include <algorithm>
#include <cstddef>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
7 
8 namespace wavrec_tool {
9  class IWaveformRecog {
10  public:
11  virtual ~IWaveformRecog() noexcept = default;
12 
13  // Calculate multi-class probabilities for waveform
14  virtual std::vector<std::vector<float>> predictWaveformType(
15  const std::vector<std::vector<float>>&) const = 0;
16 
17  // ---------------------------------------------------------------------
18  // Return a vector of booleans of the same size as the input waveform.
19  // The value of each element of the vector represents whether the
20  // corresponding time bin of the waveform is in an ROI or not.
21  // ---------------------------------------------------------------------
22  std::vector<bool>
23  findROI(const std::vector<float>& adcin) const
24  {
25  std::vector<bool> bvec(fWaveformSize, false);
26  if (adcin.size() != fWaveformSize) { return bvec; }
27 
28  std::vector<std::vector<float>> predv = scanWaveform(adcin);
29 
30  // .. set to true all bins in the output vector that are in windows identified as signals
31  int j1;
32  for (unsigned int i = 0; i < fNumStrides; i++) {
33  j1 = i * fStrideLength;
34  if (predv[i][0] > fCnnPredCut) { std::fill_n(bvec.begin() + j1, fWindowSize, true); }
35  }
36  // .. last window is a special case
37  if (predv[fNumStrides][0] > fCnnPredCut) {
38  j1 = fNumStrides * fStrideLength;
39  std::fill_n(bvec.begin() + j1, fLastWindowSize, true);
40  }
41  return bvec;
42  }
43 
44  // -------------------------------------------------------------
45  // Return a vector of floats of the same size as the input
46  // waveform. The value in each bin represents the probability
47  // whether that bin is in an ROI or not
48  // -------------------------------------------------------------
49  std::vector<float>
50  predROI(const std::vector<float>& adcin) const
51  {
52  std::vector<float> fvec(fWaveformSize, 0.);
53  if (adcin.size() != fWaveformSize) { return fvec; }
54 
55  std::vector<std::vector<float>> predv = scanWaveform(adcin);
56 
57  // .. set value in each bin of output vector to the prediction for the window it is in
58  int j1;
59  for (unsigned int i = 0; i < fNumStrides; i++) {
60  j1 = i * fStrideLength;
61  std::fill_n(fvec.begin() + j1, fWindowSize, predv[i][0]);
62  }
63  // .. last window is a special case
64  j1 = fNumStrides * fStrideLength;
65  std::fill_n(fvec.begin() + j1, fLastWindowSize, predv[fNumStrides][0]);
66  return fvec;
67  }
68 
69  protected:
71  findFile(const char* fileName) const
72  {
73  std::string fname_out;
74  cet::search_path sp("FW_SEARCH_PATH");
75  if (!sp.find_file(fileName, fname_out)) {
76  struct stat buffer;
77  if (stat(fileName, &buffer) == 0) { fname_out = fileName; }
78  else {
80  << "Could not find the model file " << fileName;
81  }
82  }
83  return fname_out;
84  }
85 
86  void
88  {
89  fCnnPredCut = pset.get<float>("CnnPredCut", 0.5);
90  fWaveformSize = pset.get<unsigned int>("WaveformSize", 0); // 6000
91  std::string fMeanFilename = pset.get<std::string>("MeanFilename", "");
92  std::string fScaleFilename = pset.get<std::string>("ScaleFilename", "");
93 
94  // ... load the mean and scale (std) vectors
95  if (!fMeanFilename.empty() && !fScaleFilename.empty()) {
96  float val;
97  std::ifstream meanfile(findFile(fMeanFilename.c_str()).c_str());
98  if (meanfile.is_open()) {
99  while (meanfile >> val)
100  meanvec.push_back(val);
101  meanfile.close();
102  if (meanvec.size() != fWaveformSize) {
103  throw cet::exception("WaveformRecogTf_tool")
104  << "vector of mean values does not match waveform size, exiting" << std::endl;
105  }
106  }
107  else {
108  throw cet::exception("WaveformRecogTf_tool")
109  << "failed opening StdScaler mean file, exiting" << std::endl;
110  }
111  std::ifstream scalefile(findFile(fScaleFilename.c_str()).c_str());
112  if (scalefile.is_open()) {
113  while (scalefile >> val)
114  scalevec.push_back(val);
115  scalefile.close();
116  if (scalevec.size() != fWaveformSize) {
117  throw cet::exception("WaveformRecogTf_tool")
118  << "vector of scale values does not match waveform size, exiting" << std::endl;
119  }
120  }
121  else {
122  throw cet::exception("WaveformRecogTf_tool")
123  << "failed opening StdScaler scale file, exiting" << std::endl;
124  }
125  }
126  else {
127  fCnnMean = pset.get<float>("CnnMean", 0.);
128  fCnnScale = pset.get<float>("CnnScale", 1.);
129  meanvec.resize(fWaveformSize);
130  std::fill(meanvec.begin(), meanvec.end(), fCnnMean);
131  scalevec.resize(fWaveformSize);
132  std::fill(scalevec.begin(), scalevec.end(), fCnnScale);
133  }
134 
135  fWindowSize = pset.get<unsigned int>("ScanWindowSize", 0); // 200
136  fStrideLength = pset.get<unsigned int>("StrideLength", 0); // 150
137 
138  if (fWaveformSize > 0 && fWindowSize > 0) {
139  float dmn =
140  fWaveformSize - fWindowSize; // dist between trail edge of 1st win & last data point
141  fNumStrides = std::ceil(dmn / float(fStrideLength)); // # strides to scan entire waveform
142  unsigned int overshoot = fNumStrides * fStrideLength + fWindowSize - fWaveformSize;
143  fLastWindowSize = fWindowSize - overshoot;
144  unsigned int numwindows = fNumStrides + 1;
145  std::cout << " !!!!! WaveformRoiFinder: WindowSize = " << fWindowSize
146  << ", StrideLength = " << fStrideLength
147  << ", dmn/StrideLength = " << dmn / fStrideLength << std::endl;
148  std::cout << " dmn = " << dmn << ", NumStrides = " << fNumStrides
149  << ", overshoot = " << overshoot << ", LastWindowSize = " << fLastWindowSize
150  << ", numwindows = " << numwindows << std::endl;
151  }
152  }
153 
154  private:
155  std::vector<float> scalevec;
156  std::vector<float> meanvec;
157  float fCnnMean;
158  float fCnnScale;
159  float fCnnPredCut;
160  unsigned int fWaveformSize; // Full waveform size
161  unsigned int fWindowSize; // Scan window size
162  unsigned int fStrideLength; // Offset (in #time ticks) between scan windows
163  unsigned int fNumStrides;
164  unsigned int fLastWindowSize;
165 
166  std::vector<std::vector<float>>
167  scanWaveform(const std::vector<float>& adcin) const
168  {
169  // .. rescale input waveform for CNN
170  std::vector<float> adc(fWaveformSize);
171  for (size_t itck = 0; itck < fWaveformSize; ++itck) {
172  adc[itck] = (adcin[itck] - meanvec[itck]) / scalevec[itck];
173  }
174 
175  // .. create a vector of windows
176  std::vector<std::vector<float>> wwv(fNumStrides + 1, std::vector<float>(fWindowSize, 0.));
177 
178  // .. fill each window with adc values
179  unsigned int j1, j2, k;
180  for (unsigned int i = 0; i < fNumStrides; i++) {
181  j1 = i * fStrideLength;
182  j2 = j1 + fWindowSize;
183  k = 0;
184  for (unsigned int j = j1; j < j2; j++) {
185  wwv[i][k] = adc[j];
186  k++;
187  }
188  }
189  // .. last window is a special case
190  j1 = fNumStrides * fStrideLength;
191  j2 = j1 + fLastWindowSize;
192  k = 0;
193  for (unsigned int j = j1; j < j2; j++) {
194  wwv[fNumStrides][k] = adc[j];
195  k++;
196  }
197 
198  // ... use waveform recognition CNN to perform inference on each window
199  return predictWaveformType(wwv);
200  }
201  };
202 }
203 
204 #endif
std::vector< float > scalevec
std::string string
Definition: nybbler.cc:12
virtual ~IWaveformRecog() noexcept=default
struct vector vector
int16_t adc
Definition: CRTFragment.hh:202
std::string findFile(const char *fileName) const
std::vector< float > meanvec
std::vector< std::vector< float > > scanWaveform(const std::vector< float > &adcin) const
fileName
Definition: dumpTree.py:9
std::vector< bool > findROI(const std::vector< float > &adcin) const
T get(std::string const &key) const
Definition: ParameterSet.h:271
virtual std::vector< std::vector< float > > predictWaveformType(const std::vector< std::vector< float >> &) const =0
cet::coded_exception< errors::ErrorCodes, ExceptionDetail::translate > Exception
Definition: Exception.h:66
def fill(s)
Definition: translator.py:93
std::string find_file(std::string const &filename) const
Definition: search_path.cc:96
std::vector< float > predROI(const std::vector< float > &adcin) const
void setupWaveRecRoiParams(const fhicl::ParameterSet &pset)
cet::coded_exception< error, detail::translate > exception
Definition: exception.h:33
QTextStream & endl(QTextStream &s)