Public Member Functions | Private Member Functions | Private Attributes | List of all members
PointIdAlgTools::PointIdAlgKeras Class Reference
Inheritance diagram for PointIdAlgTools::PointIdAlgKeras:
PointIdAlgTools::IPointIdAlg img::DataProviderAlg

Public Member Functions

 PointIdAlgKeras (const fhicl::ParameterSet &pset)
 
 PointIdAlgKeras (const Config &config)
 
std::vector< float > Run (std::vector< std::vector< float >> const &inp2d) const override
 
std::vector< std::vector< float > > Run (std::vector< std::vector< std::vector< float >>> const &inps, int samples=-1) const override
 
- Public Member Functions inherited from PointIdAlgTools::IPointIdAlg
virtual ~IPointIdAlg () noexcept=default
 
float predictIdValue (unsigned int wire, float drift, size_t outIdx=0)
 
std::vector< float > predictIdVector (unsigned int wire, float drift)
 
std::vector< std::vector< float > > predictIdVectors (const std::vector< std::pair< unsigned int, float >> &points)
 
std::vector< std::string > const & outputLabels (void) const
 
bool isInsideFiducialRegion (unsigned int wire, float drift) const
 
- Public Member Functions inherited from img::DataProviderAlg
 DataProviderAlg (const fhicl::ParameterSet &pset)
 
 DataProviderAlg (const Config &config)
 
virtual ~DataProviderAlg ()
 
bool setWireDriftData (const detinfo::DetectorClocksData &clock_data, const detinfo::DetectorPropertiesData &det_prop, const std::vector< recob::Wire > &wires, unsigned int plane, unsigned int tpc, unsigned int cryo)
 
std::vector< float > const & wireData (size_t widx) const
 
std::vector< std::vector< float > > getPatch (size_t wire, float drift, size_t patchSizeW, size_t patchSizeD) const
 
float getPixelOrZero (int wire, int drift) const
 
double getAdcSum () const
 
size_t getAdcArea () const
 
float poolMax (int wire, int drift, size_t r=0) const
 Pool max value in a patch around the wire/drift pixel. More...
 
unsigned int Cryo () const
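 (Note: the description "Pool sum of pixels in a patch around the wire/drift pixel." was attached to Cryo() by mistake; it describes a pool-sum operation, not Cryo().)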
 
unsigned int TPC () const
 
unsigned int Plane () const
 
unsigned int NWires () const
 
unsigned int NScaledDrifts () const
 
unsigned int NCachedDrifts () const
 
unsigned int DriftWindow () const
 
float ZeroLevel () const
 Level of zero ADC after scaling. More...
 
double LifetimeCorrection (detinfo::DetectorClocksData const &clock_data, detinfo::DetectorPropertiesData const &det_prop, double tick) const
 

Private Member Functions

std::string findFile (const char *fileName) const
 

Private Attributes

std::unique_ptr< keras::KerasModel > m
 
std::string fNNetModelFilePath
 

Additional Inherited Members

- Public Types inherited from img::DataProviderAlg
enum  EDownscaleMode { kMax = 1, kMaxMean = 2, kMean = 3 }
 
- Protected Member Functions inherited from PointIdAlgTools::IPointIdAlg
bool bufferPatch (size_t wire, float drift, std::vector< std::vector< float >> &patch)
 
bool bufferPatch (size_t wire, float drift)
 
void resizePatch (void)
 
- Protected Member Functions inherited from img::DataProviderAlg
std::vector< float > downscaleMax (std::size_t dst_size, std::vector< float > const &adc, size_t tick0) const
 
std::vector< float > downscaleMaxMean (std::size_t dst_size, std::vector< float > const &adc, size_t tick0) const
 
std::vector< float > downscaleMean (std::size_t dst_size, std::vector< float > const &adc, size_t tick0) const
 
std::vector< float > downscale (std::size_t dst_size, std::vector< float > const &adc, size_t tick0) const
 
size_t getDriftIndex (float drift) const
 
std::optional< std::vector< float > > setWireData (std::vector< float > const &adc, size_t wireIdx) const
 
bool patchFromDownsampledView (size_t wire, float drift, size_t size_w, size_t size_d, std::vector< std::vector< float >> &patch) const
 
bool patchFromOriginalView (size_t wire, float drift, size_t size_w, size_t size_d, std::vector< std::vector< float >> &patch) const
 
virtual DataProviderAlgView resizeView (detinfo::DetectorClocksData const &clock_data, detinfo::DetectorPropertiesData const &det_prop, size_t wires, size_t drifts)
 
- Protected Attributes inherited from PointIdAlgTools::IPointIdAlg
std::vector< std::string > fNNetOutputs
 
size_t fPatchSizeW
 
size_t fPatchSizeD
 
std::vector< std::vector< float > > fWireDriftPatch
 
size_t fCurrentWireIdx
 
size_t fCurrentScaledDrift
 
- Protected Attributes inherited from img::DataProviderAlg
DataProviderAlgView fAlgView
 
EDownscaleMode fDownscaleMode
 
size_t fDriftWindow
 
bool fDownscaleFullView
 
float fDriftWindowInv
 
calo::CalorimetryAlg fCalorimetryAlg
 
geo::GeometryCore const * fGeometry
 

Detailed Description

Definition at line 19 of file PointIdAlgKeras_tool.cc.

Constructor & Destructor Documentation

PointIdAlgTools::PointIdAlgKeras::PointIdAlgKeras ( const fhicl::ParameterSet & pset)
inline explicit

Definition at line 21 of file PointIdAlgKeras_tool.cc.

23  {}
PointIdAlgKeras(const fhicl::ParameterSet &pset)
PointIdAlgTools::PointIdAlgKeras::PointIdAlgKeras ( const Config & config)
explicit

Definition at line 37 of file PointIdAlgKeras_tool.cc.

38  {
39  // ... Get common config vars
40  fNNetOutputs = config.NNetOutputs();
41  fPatchSizeW = config.PatchSizeW();
42  fPatchSizeD = config.PatchSizeD();
43  fCurrentWireIdx = 99999;
44  fCurrentScaledDrift = 99999;
45 
46  // ... Get "optional" config vars specific to tf interface
47  std::string s_cfgvr;
48  if (config.NNetModelFile(s_cfgvr)) { fNNetModelFilePath = s_cfgvr; }
49  else {
50  fNNetModelFilePath = "mycnn";
51  }
52 
53  if ((fNNetModelFilePath.length() > 5) &&
54  (fNNetModelFilePath.compare(fNNetModelFilePath.length() - 5, 5, ".nnet") == 0)) {
55  m = std::make_unique<keras::KerasModel>(findFile(fNNetModelFilePath.c_str()).c_str());
56  mf::LogInfo("PointIdAlgKeras") << "Keras model loaded.";
57  }
58  else {
59  mf::LogError("PointIdAlgKeras") << "File name extension not supported.";
60  }
61 
62  resizePatch();
63  }
std::string string
Definition: nybbler.cc:12
MaybeLogger_< ELseverityLevel::ELsev_info, false > LogInfo
MaybeLogger_< ELseverityLevel::ELsev_error, false > LogError
std::string findFile(const char *fileName) const
std::vector< std::string > fNNetOutputs
Definition: IPointIdAlg.h:152
static Config * config
Definition: config.cpp:1054
std::unique_ptr< keras::KerasModel > m

Member Function Documentation

std::string PointIdAlgTools::PointIdAlgKeras::findFile ( const char *  fileName) const
private

Definition at line 67 of file PointIdAlgKeras_tool.cc.

68  {
69  std::string fname_out;
70  cet::search_path sp("FW_SEARCH_PATH");
71  if (!sp.find_file(fileName, fname_out)) {
72  struct stat buffer;
73  if (stat(fileName, &buffer) == 0) { fname_out = fileName; }
74  else {
75  throw art::Exception(art::errors::NotFound) << "Could not find the model file " << fileName;
76  }
77  }
78  return fname_out;
79  }
std::string string
Definition: nybbler.cc:12
fileName
Definition: dumpTree.py:9
cet::coded_exception< errors::ErrorCodes, ExceptionDetail::translate > Exception
Definition: Exception.h:66
std::vector< float > PointIdAlgTools::PointIdAlgKeras::Run ( std::vector< std::vector< float >> const &  inp2d) const
override virtual

Implements PointIdAlgTools::IPointIdAlg.

Definition at line 83 of file PointIdAlgKeras_tool.cc.

84  {
85  std::vector<std::vector<std::vector<float>>> inp3d;
86  inp3d.push_back(inp2d); // lots of copy, should add 2D to keras...
87 
88  keras::DataChunk2D sample;
89  sample.set_data(inp3d);
90  return m->compute_output(&sample);
91  }
virtual void set_data(std::vector< std::vector< std::vector< float > > > const &d)
Definition: keras_model.h:64
std::unique_ptr< keras::KerasModel > m
std::vector< std::vector< float > > PointIdAlgTools::PointIdAlgKeras::Run ( std::vector< std::vector< std::vector< float >>> const &  inps,
int  samples = -1 
) const
override virtual

Implements PointIdAlgTools::IPointIdAlg.

Definition at line 95 of file PointIdAlgKeras_tool.cc.

96  {
97 
98  if ((samples == 0) || inps.empty() || inps.front().empty() || inps.front().front().empty()) {
99  return std::vector<std::vector<float>>();
100  }
101 
102  if ((samples == -1) || (samples > (long long int)inps.size())) { samples = inps.size(); }
103 
104  std::vector<std::vector<float>> out;
105 
106  for (long long int s = 0; s < samples; ++s) {
107  std::vector<std::vector<std::vector<float>>> inp3d;
108  inp3d.push_back(inps[s]); // lots of copy, should add 2D to keras...
109 
110  keras::DataChunk* sample = new keras::DataChunk2D();
111  sample->set_data(inp3d); // and more copy...
112  out.push_back(m->compute_output(sample));
113  delete sample;
114  }
115 
116  return out;
117  }
virtual void set_data(std::vector< std::vector< std::vector< float > > > const &)
Definition: keras_model.h:47
std::unique_ptr< keras::KerasModel > m
static QCString * s
Definition: config.cpp:1042

Member Data Documentation

std::string PointIdAlgTools::PointIdAlgKeras::fNNetModelFilePath
private

Definition at line 32 of file PointIdAlgKeras_tool.cc.

std::unique_ptr<keras::KerasModel> PointIdAlgTools::PointIdAlgKeras::m
private

Definition at line 31 of file PointIdAlgKeras_tool.cc.


The documentation for this class was generated from the following file: