WaveformRecogTf_tool.cc
Go to the documentation of this file.
6 
7 #include <sys/stat.h>
#include <stdexcept>
8 
9 namespace wavrec_tool {
10 
12  public:
13  explicit WaveformRecogTf(const fhicl::ParameterSet& pset);
14 
15  std::vector<std::vector<float>> predictWaveformType(
16  const std::vector<std::vector<float>>&) const override;
17 
18  private:
19  std::unique_ptr<tf::Graph> g; // network graph
21  std::vector<std::string> fNNetOutputPattern;
22  };
23 
24  // ------------------------------------------------------
26  {
27  fNNetModelFilePath = pset.get<std::string>("NNetModelFile", "mymodel.pb");
29  pset.get<std::vector<std::string>>("NNetOutputPattern", {"cnn_output", "dense_3"});
30  if ((fNNetModelFilePath.length() > 3) &&
31  (fNNetModelFilePath.compare(fNNetModelFilePath.length() - 3, 3, ".pb") == 0)) {
33  if (!g) { throw art::Exception(art::errors::Unknown) << "TF model failed."; }
34  mf::LogInfo("WaveformRecogTf") << "TF model loaded.";
35  }
36  else {
37  mf::LogError("WaveformRecogTf") << "File name extension not supported.";
38  }
39 
41  }
42 
43  // ------------------------------------------------------
44  std::vector<std::vector<float>>
45  WaveformRecogTf::predictWaveformType(const std::vector<std::vector<float>>& waveforms) const
46  {
47  if (waveforms.empty() || waveforms.front().empty()) {
48  return std::vector<std::vector<float>>();
49  }
50 
51  long long int samples = waveforms.size(), numtcks = waveforms.front().size();
52 
53  tensorflow::Tensor _x(tensorflow::DT_FLOAT, tensorflow::TensorShape({samples, numtcks, 1}));
54  auto input_map = _x.tensor<float, 3>();
55  for (long long int s = 0; s < samples; ++s) {
56  const auto& wvfrm = waveforms[s];
57  for (long long int t = 0; t < numtcks; ++t) {
58  input_map(s, t, 0) = wvfrm[t];
59  }
60  }
61 
62  return g->run(_x);
63  }
64 
65 }
WaveformRecogTf(const fhicl::ParameterSet &pset)
#define DEFINE_ART_CLASS_TOOL(tool)
Definition: ToolMacros.h:42
std::string string
Definition: nybbler.cc:12
MaybeLogger_< ELseverityLevel::ELsev_info, false > LogInfo
struct vector vector
MaybeLogger_< ELseverityLevel::ELsev_error, false > LogError
std::string findFile(const char *fileName) const
T get(std::string const &key) const
Definition: ParameterSet.h:271
std::vector< std::vector< float > > predictWaveformType(const std::vector< std::vector< float >> &) const override
std::vector< std::string > fNNetOutputPattern
static std::unique_ptr< Graph > create(const char *graph_file_name, const std::vector< std::string > &outputs={}, int ninputs=1, int noutputs=1)
Definition: tf_graph.h:32
cet::coded_exception< errors::ErrorCodes, ExceptionDetail::translate > Exception
Definition: Exception.h:66
std::unique_ptr< tf::Graph > g
static QCString * s
Definition: config.cpp:1042
void setupWaveRecRoiParams(const fhicl::ParameterSet &pset)