PointIdAlgKeras_tool.cc
Go to the documentation of this file.
1 ////////////////////////////////////////////////////////////////////////////////////////////////////
2 // Class: PointIdAlgKeras_tool (tool version of Keras model interface in PointIdAlg)
3 // Authors: D.Stefan (Dorota.Stefan@ncbj.gov.pl), from DUNE, CERN/NCBJ, since May 2016
4 // R.Sulej (Robert.Sulej@cern.ch), from DUNE, FNAL/NCBJ, since May 2016
5 // P.Plonski, from DUNE, WUT, since May 2016
6 // D.Smith, from LArIAT, BU, 2017: real data dump
7 // M.Wang, from DUNE, FNAL, 2020: tool version
8 ////////////////////////////////////////////////////////////////////////////////////////////////////
9 
11 
14 
15 #include <sys/stat.h>
16 
17 namespace PointIdAlgTools {
18 
// PointIdAlgKeras: IPointIdAlg implementation that evaluates a Keras model
// (loaded through keras::KerasModel from a ".nnet" model file).
19  class PointIdAlgKeras : public IPointIdAlg {
20  public:
// Construct from a raw ParameterSet by converting it through fhicl::Table<Config>
// and delegating to the Config constructor.
21  explicit PointIdAlgKeras(const fhicl::ParameterSet& pset)
22  : PointIdAlgKeras(fhicl::Table<Config>(pset, {})())
23  {}
// Construct from a Config; reads the patch/output settings and, if the model
// file name ends in ".nnet", loads the Keras model.
24  explicit PointIdAlgKeras(const Config& config);
25 
// Evaluate the model on a single 2D input; returns the model output vector.
26  std::vector<float> Run(std::vector<std::vector<float>> const& inp2d) const override;
// Evaluate the model on the first `samples` 2D inputs (-1, or any value larger
// than inps.size(), means all of them); returns one output vector per input.
27  std::vector<std::vector<float>> Run(std::vector<std::vector<std::vector<float>>> const& inps,
28  int samples = -1) const override;
29 
30  private:
// The loaded Keras model; left null when the configured file name is rejected.
31  std::unique_ptr<keras::KerasModel> m;
// Resolve a model file name via FW_SEARCH_PATH, falling back to the literal
// path; throws art::Exception if neither exists.
33  std::string findFile(const char* fileName) const;
34  };
35 
36  // ------------------------------------------------------
38  {
// Constructor body: copy the common settings from the Config, pick the model
// file path, load the Keras model when the path ends in ".nnet", then size the
// patch buffers.
39  // ... Get common config vars
40  fNNetOutputs = config.NNetOutputs();
41  fPatchSizeW = config.PatchSizeW();
42  fPatchSizeD = config.PatchSizeD();
// Out-of-range sentinel values; presumably "no patch position cached yet" —
// TODO confirm against the base class that uses them.
43  fCurrentWireIdx = 99999;
44  fCurrentScaledDrift = 99999;
45 
46  // ... Get "optional" config vars: the Keras model file path
47  std::string s_cfgvr;
// NNetModelFile(s_cfgvr) reports whether a value was supplied (presumably an
// optional fhicl parameter — verify in Config).
48  if (config.NNetModelFile(s_cfgvr)) { fNNetModelFilePath = s_cfgvr; }
49  else {
// NOTE(review): this fallback has no ".nnet" suffix, so it always takes the
// "extension not supported" branch below and no model gets loaded.
50  fNNetModelFilePath = "mycnn";
51  }
52 
// Accept only names longer than 5 characters that end in ".nnet"; findFile()
// resolves the name and throws art::Exception if the file cannot be found.
53  if ((fNNetModelFilePath.length() > 5) &&
54  (fNNetModelFilePath.compare(fNNetModelFilePath.length() - 5, 5, ".nnet") == 0)) {
55  m = std::make_unique<keras::KerasModel>(findFile(fNNetModelFilePath.c_str()).c_str());
56  mf::LogInfo("PointIdAlgKeras") << "Keras model loaded.";
57  }
58  else {
// NOTE(review): only an error is logged; construction continues with `m`
// left null. Consider failing fast here instead.
59  mf::LogError("PointIdAlgKeras") << "File name extension not supported.";
60  }
61 
// Defined in a base class (not visible here); presumably (re)allocates the patch
// buffer for fPatchSizeW x fPatchSizeD — TODO confirm.
62  resizePatch();
63  }
64 
65  // ------------------------------------------------------
68  {
69  std::string fname_out;
70  cet::search_path sp("FW_SEARCH_PATH");
71  if (!sp.find_file(fileName, fname_out)) {
72  struct stat buffer;
73  if (stat(fileName, &buffer) == 0) { fname_out = fileName; }
74  else {
75  throw art::Exception(art::errors::NotFound) << "Could not find the model file " << fileName;
76  }
77  }
78  return fname_out;
79  }
80 
81  // ------------------------------------------------------
82  std::vector<float>
83  PointIdAlgKeras::Run(std::vector<std::vector<float>> const& inp2d) const
84  {
85  std::vector<std::vector<std::vector<float>>> inp3d;
86  inp3d.push_back(inp2d); // lots of copy, should add 2D to keras...
87 
88  keras::DataChunk2D sample;
89  sample.set_data(inp3d);
90  return m->compute_output(&sample);
91  }
92 
93  // ------------------------------------------------------
94  std::vector<std::vector<float>>
95  PointIdAlgKeras::Run(std::vector<std::vector<std::vector<float>>> const& inps, int samples) const
96  {
97 
98  if ((samples == 0) || inps.empty() || inps.front().empty() || inps.front().front().empty()) {
99  return std::vector<std::vector<float>>();
100  }
101 
102  if ((samples == -1) || (samples > (long long int)inps.size())) { samples = inps.size(); }
103 
104  std::vector<std::vector<float>> out;
105 
106  for (long long int s = 0; s < samples; ++s) {
107  std::vector<std::vector<std::vector<float>>> inp3d;
108  inp3d.push_back(inps[s]); // lots of copy, should add 2D to keras...
109 
110  keras::DataChunk* sample = new keras::DataChunk2D();
111  sample->set_data(inp3d); // and more copy...
112  out.push_back(m->compute_output(sample));
113  delete sample;
114  }
115 
116  return out;
117  }
118 
119 }
#define DEFINE_ART_CLASS_TOOL(tool)
Definition: ToolMacros.h:42
std::string string
Definition: nybbler.cc:12
MaybeLogger_< ELseverityLevel::ELsev_info, false > LogInfo
struct vector vector
virtual void set_data(std::vector< std::vector< std::vector< float > > > const &)
Definition: keras_model.h:47
DataProviderAlg(const fhicl::ParameterSet &pset)
MaybeLogger_< ELseverityLevel::ELsev_error, false > LogError
std::string findFile(const char *fileName) const
std::vector< float > Run(std::vector< std::vector< float >> const &inp2d) const override
std::vector< std::string > fNNetOutputs
Definition: IPointIdAlg.h:152
fileName
Definition: dumpTree.py:9
static Config * config
Definition: config.cpp:1054
virtual void set_data(std::vector< std::vector< std::vector< float > > > const &d)
Definition: keras_model.h:64
cet::coded_exception< errors::ErrorCodes, ExceptionDetail::translate > Exception
Definition: Exception.h:66
std::string find_file(std::string const &filename) const
Definition: search_path.cc:96
std::unique_ptr< keras::KerasModel > m
PointIdAlgKeras(const fhicl::ParameterSet &pset)
static QCString * s
Definition: config.cpp:1042