Public Member Functions | Protected Member Functions | Private Attributes | List of all members
PointIdAlgTools::PointIdAlgTf Class Reference
Inheritance diagram for PointIdAlgTools::PointIdAlgTf:
PointIdAlgTools::IPointIdAlg, img::DataProviderAlg

Public Member Functions

 PointIdAlgTf (fhicl::Table< Config > const &table)
 
std::vector< float > Run (std::vector< std::vector< float >> const &inp2d) const override
 
std::vector< std::vector< float > > Run (std::vector< std::vector< std::vector< float >>> const &inps, int samples=-1) const override
 
- Public Member Functions inherited from PointIdAlgTools::IPointIdAlg
virtual ~IPointIdAlg () noexcept=default
 
float predictIdValue (unsigned int wire, float drift, size_t outIdx=0)
 
std::vector< float > predictIdVector (unsigned int wire, float drift)
 
std::vector< std::vector< float > > predictIdVectors (const std::vector< std::pair< unsigned int, float >> &points)
 
std::vector< std::string > const & outputLabels (void) const
 
bool isInsideFiducialRegion (unsigned int wire, float drift) const
 
- Public Member Functions inherited from img::DataProviderAlg
 DataProviderAlg (const fhicl::ParameterSet &pset)
 
 DataProviderAlg (const Config &config)
 
virtual ~DataProviderAlg ()
 
bool setWireDriftData (const detinfo::DetectorClocksData &clock_data, const detinfo::DetectorPropertiesData &det_prop, const std::vector< recob::Wire > &wires, unsigned int plane, unsigned int tpc, unsigned int cryo)
 
std::vector< float > const & wireData (size_t widx) const
 
std::vector< std::vector< float > > getPatch (size_t wire, float drift, size_t patchSizeW, size_t patchSizeD) const
 
float getPixelOrZero (int wire, int drift) const
 
double getAdcSum () const
 
size_t getAdcArea () const
 
float poolMax (int wire, int drift, size_t r=0) const
 Pool max value in a patch around the wire/drift pixel. More...
 
unsigned int Cryo () const
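 Note: the brief shown here ("Pool sum of pixels in a patch around the wire/drift pixel.") does not describe Cryo(); it appears to have been attached to this entry by mistake from a neighbouring declaration in img::DataProviderAlg. More...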
 
unsigned int TPC () const
 
unsigned int Plane () const
 
unsigned int NWires () const
 
unsigned int NScaledDrifts () const
 
unsigned int NCachedDrifts () const
 
unsigned int DriftWindow () const
 
float ZeroLevel () const
 Level of zero ADC after scaling. More...
 
double LifetimeCorrection (detinfo::DetectorClocksData const &clock_data, detinfo::DetectorPropertiesData const &det_prop, double tick) const
 

Protected Member Functions

std::string findFile (const char *fileName) const
 
- Protected Member Functions inherited from PointIdAlgTools::IPointIdAlg
bool bufferPatch (size_t wire, float drift, std::vector< std::vector< float >> &patch)
 
bool bufferPatch (size_t wire, float drift)
 
void resizePatch (void)
 
- Protected Member Functions inherited from img::DataProviderAlg
std::vector< float > downscaleMax (std::size_t dst_size, std::vector< float > const &adc, size_t tick0) const
 
std::vector< float > downscaleMaxMean (std::size_t dst_size, std::vector< float > const &adc, size_t tick0) const
 
std::vector< float > downscaleMean (std::size_t dst_size, std::vector< float > const &adc, size_t tick0) const
 
std::vector< float > downscale (std::size_t dst_size, std::vector< float > const &adc, size_t tick0) const
 
size_t getDriftIndex (float drift) const
 
std::optional< std::vector< float > > setWireData (std::vector< float > const &adc, size_t wireIdx) const
 
bool patchFromDownsampledView (size_t wire, float drift, size_t size_w, size_t size_d, std::vector< std::vector< float >> &patch) const
 
bool patchFromOriginalView (size_t wire, float drift, size_t size_w, size_t size_d, std::vector< std::vector< float >> &patch) const
 
virtual DataProviderAlgView resizeView (detinfo::DetectorClocksData const &clock_data, detinfo::DetectorPropertiesData const &det_prop, size_t wires, size_t drifts)
 

Private Attributes

std::unique_ptr< tf::Graph > g
 
std::vector< std::string > fNNetOutputPattern
 
std::string fNNetModelFilePath
 

Additional Inherited Members

- Public Types inherited from img::DataProviderAlg
enum  EDownscaleMode { kMax = 1, kMaxMean = 2, kMean = 3 }
 
- Protected Attributes inherited from PointIdAlgTools::IPointIdAlg
std::vector< std::string > fNNetOutputs
 
size_t fPatchSizeW
 
size_t fPatchSizeD
 
std::vector< std::vector< float > > fWireDriftPatch
 
size_t fCurrentWireIdx
 
size_t fCurrentScaledDrift
 
- Protected Attributes inherited from img::DataProviderAlg
DataProviderAlgView fAlgView
 
EDownscaleMode fDownscaleMode
 
size_t fDriftWindow
 
bool fDownscaleFullView
 
float fDriftWindowInv
 
calo::CalorimetryAlg fCalorimetryAlg
 
geo::GeometryCore const * fGeometry
 

Detailed Description

Definition at line 20 of file PointIdAlgTf_tool.cc.

Constructor & Destructor Documentation

PointIdAlgTools::PointIdAlgTf::PointIdAlgTf ( fhicl::Table< Config > const &  table)
explicit

Definition at line 38 of file PointIdAlgTf_tool.cc.

38  : img::DataProviderAlg(table())
39  {
40  // ... Get common config vars
41  fNNetOutputs = table().NNetOutputs();
42  fPatchSizeW = table().PatchSizeW();
43  fPatchSizeD = table().PatchSizeD();
44  fCurrentWireIdx = 99999;
45  fCurrentScaledDrift = 99999;
46 
47  // ... Get "optional" config vars specific to tf interface
48  std::string s_cfgvr;
49  if (table().NNetModelFile(s_cfgvr)) { fNNetModelFilePath = s_cfgvr; }
50  else {
51  fNNetModelFilePath = "mycnn";
52  }
53  std::vector<std::string> vs_cfgvr;
54  if (table().NNetOutputPattern(vs_cfgvr)) { fNNetOutputPattern = vs_cfgvr; }
55  else {
56  fNNetOutputPattern = {"cnn_output", "_netout"};
57  }
58 
59  if ((fNNetModelFilePath.length() > 3) &&
60  (fNNetModelFilePath.compare(fNNetModelFilePath.length() - 3, 3, ".pb") == 0)) {
61  g = tf::Graph::create(findFile(fNNetModelFilePath.c_str()).c_str(), fNNetOutputPattern);
62  if (!g) { throw art::Exception(art::errors::Unknown) << "TF model failed."; }
63  mf::LogInfo("PointIdAlgTf") << "TF model loaded.";
64  }
65  else {
66  mf::LogError("PointIdAlgTf") << "File name extension not supported.";
67  }
68 
69  resizePatch();
70  }
std::string string
Definition: nybbler.cc:12
MaybeLogger_< ELseverityLevel::ELsev_info, false > LogInfo
std::string findFile(const char *fileName) const
std::vector< std::string > fNNetOutputPattern
MaybeLogger_< ELseverityLevel::ELsev_error, false > LogError
std::vector< std::string > fNNetOutputs
Definition: IPointIdAlg.h:152
static std::unique_ptr< Graph > create(const char *graph_file_name, const std::vector< std::string > &outputs={}, int ninputs=1, int noutputs=1)
Definition: tf_graph.h:32
cet::coded_exception< errors::ErrorCodes, ExceptionDetail::translate > Exception
Definition: Exception.h:66
std::unique_ptr< tf::Graph > g

Member Function Documentation

std::string PointIdAlgTools::PointIdAlgTf::findFile ( const char *  fileName) const
protected

Definition at line 74 of file PointIdAlgTf_tool.cc.

75  {
76  std::string fname_out;
77  cet::search_path sp("FW_SEARCH_PATH");
78  if (!sp.find_file(fileName, fname_out)) {
79  struct stat buffer;
80  if (stat(fileName, &buffer) == 0) { fname_out = fileName; }
81  else {
82  throw art::Exception(art::errors::NotFound) << "Could not find the model file " << fileName;
83  }
84  }
85  return fname_out;
86  }
std::string string
Definition: nybbler.cc:12
fileName
Definition: dumpTree.py:9
cet::coded_exception< errors::ErrorCodes, ExceptionDetail::translate > Exception
Definition: Exception.h:66
std::vector< float > PointIdAlgTools::PointIdAlgTf::Run ( std::vector< std::vector< float >> const &  inp2d) const
override virtual

Implements PointIdAlgTools::IPointIdAlg.

Definition at line 90 of file PointIdAlgTf_tool.cc.

91  {
92  long long int rows = inp2d.size(), cols = inp2d.front().size();
93 
94  tensorflow::Tensor _x(tensorflow::DT_FLOAT, tensorflow::TensorShape({1, rows, cols, 1}));
95  auto input_map = _x.tensor<float, 4>();
96  for (long long int r = 0; r < rows; ++r) {
97  const auto& row = inp2d[r];
98  for (long long int c = 0; c < cols; ++c) {
99  input_map(0, r, c, 0) = row[c];
100  }
101  }
102 
103  auto out = g->run(_x);
104  if (!out.empty())
105  return out.front();
106  else
107  return std::vector<float>();
108  }
std::unique_ptr< tf::Graph > g
std::vector< std::vector< float > > PointIdAlgTools::PointIdAlgTf::Run ( std::vector< std::vector< std::vector< float >>> const &  inps,
int  samples = -1 
) const
override virtual

Implements PointIdAlgTools::IPointIdAlg.

Definition at line 112 of file PointIdAlgTf_tool.cc.

113  {
114 
115  if ((samples == 0) || inps.empty() || inps.front().empty() || inps.front().front().empty()) {
116  return std::vector<std::vector<float>>();
117  }
118 
119  if ((samples == -1) || (samples > (long long int)inps.size())) { samples = inps.size(); }
120 
121  long long int rows = inps.front().size(), cols = inps.front().front().size();
122 
123  tensorflow::Tensor _x(tensorflow::DT_FLOAT, tensorflow::TensorShape({samples, rows, cols, 1}));
124  auto input_map = _x.tensor<float, 4>();
125  for (long long int s = 0; s < samples; ++s) {
126  const auto& sample = inps[s];
127  for (long long int r = 0; r < rows; ++r) {
128  const auto& row = sample[r];
129  for (long long int c = 0; c < cols; ++c) {
130  input_map(s, r, c, 0) = row[c];
131  }
132  }
133  }
134  return g->run(_x);
135  }
std::unique_ptr< tf::Graph > g
static QCString * s
Definition: config.cpp:1042

Member Data Documentation

std::string PointIdAlgTools::PointIdAlgTf::fNNetModelFilePath
private

Definition at line 34 of file PointIdAlgTf_tool.cc.

std::vector<std::string> PointIdAlgTools::PointIdAlgTf::fNNetOutputPattern
private

Definition at line 33 of file PointIdAlgTf_tool.cc.

std::unique_ptr<tf::Graph> PointIdAlgTools::PointIdAlgTf::g
private

Definition at line 32 of file PointIdAlgTf_tool.cc.


The documentation for this class was generated from the following file: