Public Member Functions | Private Attributes | List of all members
nnet::TfModelInterface Class Reference

#include <PointIdAlg.h>

Inheritance diagram for nnet::TfModelInterface:
nnet::ModelInterface

Public Member Functions

 TfModelInterface (const char *modelFileName)
 
std::vector< std::vector< float > > Run (std::vector< std::vector< std::vector< float >>> const &inps, int samples=-1) override
 
std::vector< float > Run (std::vector< std::vector< float >> const &inp2d) override
 
- Public Member Functions inherited from nnet::ModelInterface
virtual ~ModelInterface ()
 

Private Attributes

std::unique_ptr< tf::Graph > g
 

Additional Inherited Members

- Protected Member Functions inherited from nnet::ModelInterface
std::string findFile (const char *fileName) const
 

Detailed Description

Definition at line 75 of file PointIdAlg.h.

Constructor & Destructor Documentation

nnet::TfModelInterface::TfModelInterface ( const char *  modelFileName)

Definition at line 91 of file PointIdAlg.cxx.

92 {
93  g = tf::Graph::create(nnet::ModelInterface::findFile(modelFileName).c_str(),
94  {"cnn_output", "_netout"});
95  if (!g) { throw art::Exception(art::errors::Unknown) << "TF model failed."; }
96 
97  mf::LogInfo("TfModelInterface") << "TF model loaded.";
98 }
MaybeLogger_< ELseverityLevel::ELsev_info, false > LogInfo
static std::unique_ptr< Graph > create(const char *graph_file_name, const std::vector< std::string > &outputs={}, int ninputs=1, int noutputs=1)
Definition: tf_graph.h:32
cet::coded_exception< errors::ErrorCodes, ExceptionDetail::translate > Exception
Definition: Exception.h:66
std::string findFile(const char *fileName) const
Definition: PointIdAlg.cxx:49
std::unique_ptr< tf::Graph > g
Definition: PointIdAlg.h:84

Member Function Documentation

std::vector< std::vector< float > > nnet::TfModelInterface::Run ( std::vector< std::vector< std::vector< float >>> const &  inps,
int  samples = -1 
)
override virtual

Reimplemented from nnet::ModelInterface.

Definition at line 102 of file PointIdAlg.cxx.

103 {
104  if ((samples == 0) || inps.empty() || inps.front().empty() || inps.front().front().empty())
105  return std::vector<std::vector<float>>();
106 
107  if ((samples == -1) || (samples > (long long int)inps.size())) { samples = inps.size(); }
108 
109  long long int rows = inps.front().size(), cols = inps.front().front().size();
110 
111  tensorflow::Tensor _x(tensorflow::DT_FLOAT, tensorflow::TensorShape({samples, rows, cols, 1}));
112  auto input_map = _x.tensor<float, 4>();
113  for (long long int s = 0; s < samples; ++s) {
114  const auto& sample = inps[s];
115  for (long long int r = 0; r < rows; ++r) {
116  const auto& row = sample[r];
117  for (long long int c = 0; c < cols; ++c) {
118  input_map(s, r, c, 0) = row[c];
119  }
120  }
121  }
122 
123  return g->run(_x);
124 }
struct vector vector
std::unique_ptr< tf::Graph > g
Definition: PointIdAlg.h:84
static QCString * s
Definition: config.cpp:1042
std::vector< float > nnet::TfModelInterface::Run ( std::vector< std::vector< float >> const &  inp2d)
override virtual

Implements nnet::ModelInterface.

Definition at line 128 of file PointIdAlg.cxx.

129 {
130  long long int rows = inp2d.size(), cols = inp2d.front().size();
131 
132  tensorflow::Tensor _x(tensorflow::DT_FLOAT, tensorflow::TensorShape({1, rows, cols, 1}));
133  auto input_map = _x.tensor<float, 4>();
134  for (long long int r = 0; r < rows; ++r) {
135  const auto& row = inp2d[r];
136  for (long long int c = 0; c < cols; ++c) {
137  input_map(0, r, c, 0) = row[c];
138  }
139  }
140 
141  auto out = g->run(_x);
142  if (!out.empty())
143  return out.front();
144  else
145  return std::vector<float>();
146 }
std::unique_ptr< tf::Graph > g
Definition: PointIdAlg.h:84

Member Data Documentation

std::unique_ptr<tf::Graph> nnet::TfModelInterface::g
private

Definition at line 84 of file PointIdAlg.h.


The documentation for this class was generated from the following files: