PointIdAlgTf_tool.cc
Go to the documentation of this file.
1 ////////////////////////////////////////////////////////////////////////////////////////////////////
2 // Class: PointIdAlgTf_tool (tool version of TensorFlow model interface in PointIdAlg)
3 // Authors: D.Stefan (Dorota.Stefan@ncbj.gov.pl), from DUNE, CERN/NCBJ, since May 2016
4 // R.Sulej (Robert.Sulej@cern.ch), from DUNE, FNAL/NCBJ, since May 2016
5 // P.Plonski, from DUNE, WUT, since May 2016
6 // D.Smith, from LArIAT, BU, 2017: real data dump
7 // M.Wang, from DUNE, FNAL, 2020: tool version
8 ////////////////////////////////////////////////////////////////////////////////////////////////////
9 
11 
15 
16 #include <sys/stat.h>
17 
18 namespace PointIdAlgTools {
19 
// PointIdAlgTf: IPointIdAlg tool implementation that evaluates a TensorFlow graph
// (tf::Graph) on 2D input patches; see the two Run() overloads.
20  class PointIdAlgTf : public IPointIdAlg {
21  public:
22  explicit PointIdAlgTf(fhicl::Table<Config> const& table); // configured from the fhicl table
23 
24  std::vector<float> Run(std::vector<std::vector<float>> const& inp2d) const override; // one patch, inp2d[row][col]
25  std::vector<std::vector<float>> Run(std::vector<std::vector<std::vector<float>>> const& inps,
26  int samples = -1) const override; // batch of patches; samples = how many to evaluate (-1: all)
27 
28  protected:
29  std::string findFile(const char* fileName) const; // FW_SEARCH_PATH lookup, then plain path; throws if missing
30 
31  private:
32  std::unique_ptr<tf::Graph> g; // network graph
33  std::vector<std::string> fNNetOutputPattern; // output-node name patterns (default {"cnn_output", "_netout"}); presumably handed to the graph loader — confirm
35  };
36 
37  // ------------------------------------------------------
// Constructor: reads the common patch/output settings and the TF-specific options
// from the fhicl table, then loads the model graph.
//  - NNetModelFile is optional and defaults to "mycnn".
//  - NNetOutputPattern is optional and defaults to {"cnn_output", "_netout"}.
//  - Only model paths ending in ".pb" go through the graph-loading branch; if the
//    graph is still null afterwards, art::Exception(Unknown) "TF model failed." is thrown.
//  - NOTE(review): any other extension (including the "mycnn" default) only logs an
//    error, and this view shows no assignment to g in that branch. A later Run()
//    would then use a null graph. Consider throwing here instead — confirm with callers.
39  {
40  // ... Get common config vars
41  fNNetOutputs = table().NNetOutputs();
42  fPatchSizeW = table().PatchSizeW();
43  fPatchSizeD = table().PatchSizeD();
44  fCurrentWireIdx = 99999; // presumably a "no patch cached yet" sentinel — confirm
45  fCurrentScaledDrift = 99999; // same sentinel value as fCurrentWireIdx
46 
47  // ... Get "optional" config vars specific to tf interface
48  std::string s_cfgvr;
49  if (table().NNetModelFile(s_cfgvr)) { fNNetModelFilePath = s_cfgvr; }
50  else {
51  fNNetModelFilePath = "mycnn";
52  }
53  std::vector<std::string> vs_cfgvr;
54  if (table().NNetOutputPattern(vs_cfgvr)) { fNNetOutputPattern = vs_cfgvr; }
55  else {
56  fNNetOutputPattern = {"cnn_output", "_netout"};
57  }
58 
59  if ((fNNetModelFilePath.length() > 3) &&
60  (fNNetModelFilePath.compare(fNNetModelFilePath.length() - 3, 3, ".pb") == 0)) { // path ends in ".pb"
62  if (!g) { throw art::Exception(art::errors::Unknown) << "TF model failed."; }
63  mf::LogInfo("PointIdAlgTf") << "TF model loaded.";
64  }
65  else {
66  mf::LogError("PointIdAlgTf") << "File name extension not supported."; // NOTE(review): log only, no throw
67  }
68 
69  resizePatch(); // presumably sizes the patch buffer from fPatchSizeW/fPatchSizeD — confirm
70  }
71 
72  // ------------------------------------------------------
// findFile: resolve a model file name to a usable path.
// Looks the name up on FW_SEARCH_PATH first. If it is not found there, the name is
// accepted as given when stat() succeeds on it. Otherwise art::Exception(NotFound)
// is thrown. Returns the resolved path.
74  PointIdAlgTf::findFile(const char* fileName) const
75  {
76  std::string fname_out;
77  cet::search_path sp("FW_SEARCH_PATH");
78  if (!sp.find_file(fileName, fname_out)) {
79  struct stat buffer; // only used to test that the path exists
80  if (stat(fileName, &buffer) == 0) { fname_out = fileName; }
81  else {
82  throw art::Exception(art::errors::NotFound) << "Could not find the model file " << fileName;
83  }
84  }
85  return fname_out;
86  }
87 
88  // ------------------------------------------------------
89  std::vector<float>
90  PointIdAlgTf::Run(std::vector<std::vector<float>> const& inp2d) const
91  {
92  long long int rows = inp2d.size(), cols = inp2d.front().size();
93 
94  tensorflow::Tensor _x(tensorflow::DT_FLOAT, tensorflow::TensorShape({1, rows, cols, 1}));
95  auto input_map = _x.tensor<float, 4>();
96  for (long long int r = 0; r < rows; ++r) {
97  const auto& row = inp2d[r];
98  for (long long int c = 0; c < cols; ++c) {
99  input_map(0, r, c, 0) = row[c];
100  }
101  }
102 
103  auto out = g->run(_x);
104  if (!out.empty())
105  return out.front();
106  else
107  return std::vector<float>();
108  }
109 
110  // ------------------------------------------------------
111  std::vector<std::vector<float>>
112  PointIdAlgTf::Run(std::vector<std::vector<std::vector<float>>> const& inps, int samples) const
113  {
114 
115  if ((samples == 0) || inps.empty() || inps.front().empty() || inps.front().front().empty()) {
116  return std::vector<std::vector<float>>();
117  }
118 
119  if ((samples == -1) || (samples > (long long int)inps.size())) { samples = inps.size(); }
120 
121  long long int rows = inps.front().size(), cols = inps.front().front().size();
122 
123  tensorflow::Tensor _x(tensorflow::DT_FLOAT, tensorflow::TensorShape({samples, rows, cols, 1}));
124  auto input_map = _x.tensor<float, 4>();
125  for (long long int s = 0; s < samples; ++s) {
126  const auto& sample = inps[s];
127  for (long long int r = 0; r < rows; ++r) {
128  const auto& row = sample[r];
129  for (long long int c = 0; c < cols; ++c) {
130  input_map(s, r, c, 0) = row[c];
131  }
132  }
133  }
134  return g->run(_x);
135  }
136 
137 }
#define DEFINE_ART_CLASS_TOOL(tool)
Definition: ToolMacros.h:42
std::string string
Definition: nybbler.cc:12
MaybeLogger_< ELseverityLevel::ELsev_info, false > LogInfo
struct vector vector
std::string findFile(const char *fileName) const
std::vector< std::string > fNNetOutputPattern
DataProviderAlg(const fhicl::ParameterSet &pset)
MaybeLogger_< ELseverityLevel::ELsev_error, false > LogError
std::vector< std::string > fNNetOutputs
Definition: IPointIdAlg.h:152
fileName
Definition: dumpTree.py:9
std::vector< float > Run(std::vector< std::vector< float >> const &inp2d) const override
PointIdAlgTf(fhicl::Table< Config > const &table)
static std::unique_ptr< Graph > create(const char *graph_file_name, const std::vector< std::string > &outputs={}, int ninputs=1, int noutputs=1)
Definition: tf_graph.h:32
cet::coded_exception< errors::ErrorCodes, ExceptionDetail::translate > Exception
Definition: Exception.h:66
std::unique_ptr< tf::Graph > g
std::string find_file(std::string const &filename) const
Definition: search_path.cc:96
static QCString * s
Definition: config.cpp:1042