Public Member Functions | Static Public Member Functions | Public Attributes | Private Member Functions | Private Attributes | List of all members
tf::CTPGraph Class Reference

#include <CTPGraph.h>

Public Member Functions

 ~CTPGraph ()
 
std::vector< float > run (const std::vector< std::vector< float > > &x)
 
std::vector< std::vector< std::vector< float > > > run (const std::vector< std::vector< std::vector< float > > > &x)
 
std::vector< std::vector< std::vector< float > > > run (const std::vector< tensorflow::Tensor > &x)
 

Static Public Member Functions

static std::unique_ptr< CTPGraph > create (const char *graph_file_name, const std::vector< std::string > &outputs={}, int ninputs=1, int noutputs=1)
 

Public Attributes

int n_inputs = 1
 
int n_outputs = 1
 

Private Member Functions

 CTPGraph (const char *graph_file_name, const std::vector< std::string > &outputs, bool &success, int ninputs, int noutputs)
 Non-throwing constructor. More...
 

Private Attributes

tensorflow::Session * fSession
 
std::vector< std::string > fInputNames
 
std::vector< std::string > fOutputNames
 

Detailed Description

Definition at line 26 of file CTPGraph.h.

Constructor & Destructor Documentation

tf::CTPGraph::~CTPGraph ( )

Definition at line 117 of file CTPGraph.cc.

118 {
119  auto status = fSession->Close();
120  if (!status.ok()) {
121  std::cout << "tf::CTPGraph::dtor: " << "Close failed." << std::endl;
122  }
123  delete fSession;
124 }
tensorflow::Session * fSession
Definition: CTPGraph.h:54
tf::CTPGraph::CTPGraph ( const char *  graph_file_name,
const std::vector< std::string > &  outputs,
bool &  success,
int  ninputs,
int  noutputs 
)
private

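Non-throwing constructor. Instead of throwing, it reports failure by leaving the `success` out-parameter set to false and printing the error to std::cout. `success` becomes true only after the graph has been created in the session. The `create()` factory returns nullptr in that case.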

Definition at line 18 of file CTPGraph.cc.

19 {
20 
21 // std::cout << "Starting to build the graph" << std::endl;
22  success = false; // until all is done correctly
23 
24  n_inputs = ninputs;
25  n_outputs = noutputs;
26 
27  // Force tf to only use a single core so it doesn't eat batch farms
28  tensorflow::SessionOptions options;
29  tensorflow::ConfigProto &config = options.config;
30  config.set_inter_op_parallelism_threads(1);
31  config.set_intra_op_parallelism_threads(1);
32  config.set_use_per_session_threads(false);
33 
34 // std::cout << "Starting tf session" << std::endl;
35  auto status = tensorflow::NewSession(options, &fSession);
36  if (!status.ok())
37  {
38  std::cout << status.ToString() << std::endl;
39  return;
40  }
41 
42 // std::cout << "Session started... reading network architecture" << std::endl;
43  tensorflow::GraphDef graph_def;
44  status = tensorflow::ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
45  if (!status.ok())
46  {
47  std::cout << status.ToString() << std::endl;
48  return;
49  }
50 
51 // std::cout << "Extracting input names" << std::endl;
52  size_t ng = graph_def.node().size();
53 
54  // fill input names (TODO: generic)
55  for(int i=0; i<n_inputs; ++i)
56  {
57  fInputNames.push_back(graph_def.node()[i].name());
58  }
59 
60 // std::cout << "Extracting output names" << std::endl;
61  // last node as output if no specific name provided
62  if (outputs.empty())
63  {
64  for(int i=n_outputs; i>0; --i)
65  {
66  fOutputNames.push_back(graph_def.node()[ng - i].name());
67  }
68 
69  }
70  else // or last nodes with names containing provided strings
71  {
73  for (size_t n = 0; n < ng; ++n)
74  {
75  name = graph_def.node()[n].name();
76  auto pos = name.find("/");
77  if (pos != std::string::npos) { basename = name.substr(0, pos); }
78  else { continue; }
79 
80  bool found = false;
81  for (const auto & s : outputs)
82  {
83  if (name.find(s) != std::string::npos) { found = true; break; }
84  }
85  if (found)
86  {
87  if (!last.empty() && (basename != current))
88  {
89  fOutputNames.push_back(last);
90  }
91  current = basename;
92  last = name;
93  }
94  }
95  if (!last.empty()) { fOutputNames.push_back(last); }
96  }
97  if (fOutputNames.empty())
98  {
99  std::cout << "Output nodes not found in the graph." << std::endl;
100  return;
101  }
102 
103 // std::cout << "About to create graph" << std::endl;
104 
105  status = fSession->Create(graph_def);
106  if (!status.ok())
107  {
108  std::cout << status.ToString() << std::endl;
109  return;
110  }
111 
112  success = true; // ok, graph loaded from the file
113 
114 // std::cout << "Graph success? " << success << std::endl;
115 }
tensorflow::Session * fSession
Definition: CTPGraph.h:54
std::vector< std::string > fInputNames
Definition: CTPGraph.h:56
std::vector< std::string > fOutputNames
Definition: CTPGraph.h:57
int n_inputs
Definition: CTPGraph.h:29
int n_outputs
Definition: CTPGraph.h:30

Member Function Documentation

static std::unique_ptr<CTPGraph> tf::CTPGraph::create ( const char *  graph_file_name,
const std::vector< std::string > &  outputs = {},
int  ninputs = 1,
int  noutputs = 1 
)
inlinestatic

Definition at line 32 of file CTPGraph.h.

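32  static std::unique_ptr<CTPGraph> create(const char *graph_file_name, const std::vector<std::string> &outputs = {}, int ninputs = 1, int noutputs = 1)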
33  {
34  bool success;
35  std::unique_ptr<CTPGraph> ptr(new CTPGraph(graph_file_name, outputs, success, ninputs, noutputs));
36  if (success) { return ptr; }
37  else { return nullptr; }
38  }
CTPGraph(const char *graph_file_name, const std::vector< std::string > &outputs, bool &success, int ninputs, int noutputs)
Not-throwing constructor.
Definition: CTPGraph.cc:18
std::vector<float> tf::CTPGraph::run ( const std::vector< std::vector< float > > &  x)
std::vector< std::vector< std::vector< float > > > tf::CTPGraph::run ( const std::vector< std::vector< std::vector< float > > > &  x)

Definition at line 128 of file CTPGraph.cc.

130 {
131  // Number of objects to classify
132  const unsigned int nSamples = input.size();
133 
134  // There are two inputs to our network...
135  // 1) The 100 element dE/dx array
136  const unsigned int nEls = input.front().at(0).size();
137  tensorflow::Tensor dEdxTensor(tensorflow::DT_FLOAT,tensorflow::TensorShape({nSamples,nEls,1}));
138 
139  // 2) The 7 additional classification variables
140  const unsigned int nVars = input.front().at(1).size();
141  // NB: this input doesn't need a defined depth as it doesn't use convolutions
142  tensorflow::Tensor varsTensor(tensorflow::DT_FLOAT,tensorflow::TensorShape({nSamples,nVars}));
143 
144 // std::cout << "Input shapes: " << input.size() << ", " << input.front().size() << ", " << input.front().at(0).size() << ", " << input.front().at(1).size() << std::endl;
145 
146  // Fill the tensors
147  auto dedxInputMap = dEdxTensor.tensor<float,3>();
148  auto varsInputMap = varsTensor.tensor<float,2>();
149  for(unsigned int s = 0; s < nSamples; ++s){
150  // dEdx first
151  for(unsigned int e = 0; e < nEls; ++e){
152 // std::cout << "Adding element " << e << " with value " << input.at(s).at(0).at(e) << std::endl;
153  dedxInputMap(s,e,0) = input.at(s).at(0).at(e);
154  }
155  // And the other variables
156  for(unsigned int v = 0; v < nVars; ++v){
157 // std::cout << "Adding variable " << v << " with value " << input.at(s).at(1).at(v) << std::endl;
158  varsInputMap(s,v) = input.at(s).at(1).at(v);
159  }
160  }
161 
162  std::vector<tensorflow::Tensor> inputTensors;
163  inputTensors.push_back(dEdxTensor);
164  inputTensors.push_back(varsTensor);
165 
166 // std::cout << "Input tensors arranged inside the interface" << std::endl;
167 
168  return run(inputTensors);
169 }
std::vector< std::vector< std::vector< float > > > run(const std::vector< tensorflow::Tensor > &x)
Definition: CTPGraph.cc:173
std::vector< std::vector< std::vector< float > > > tf::CTPGraph::run ( const std::vector< tensorflow::Tensor > &  x)

Definition at line 173 of file CTPGraph.cc.

174 {
175  // Pair up the inputs with their names in the network
176  std::vector< std::pair<std::string, tensorflow::Tensor> > inputs;
177  for(int i=0; i<n_inputs; ++i){
178 // std::cout << "Pairing up input with name " << fInputNames[i] << " with input number " << i << std::endl;
179  inputs.push_back({fInputNames[i], x[i]});
180  }
181 
182  // The output from TF has dimensions nOutputs, nSamples, nNodes
183  std::vector<tensorflow::Tensor> outputs;
184  auto status = fSession->Run(inputs, fOutputNames, {}, &outputs);
185 
186 // std::cout << "Sorting out the outputs inside the interface" << std::endl;
187 
188  // Dimensions we want to return are nSamples, nOutputs, nNodes
189  std::vector< std::vector< std::vector<float> > > result;
190  if(status.ok()){
191  // The first dimension of the output vector is the number of outputs
192  const unsigned int nOut = outputs.size();
193  unsigned int nSamples = 0;
194  // Get the number of samples and check it is the same for all outputs
195  for(unsigned int o = 0; o < nOut; ++o){
196  if (o == 0){
197  nSamples = outputs[o].dim_size(0);
198  }
199  else if (nSamples != outputs[o].dim_size(0))
200  {
201  throw std::string("TF outputs size inconsistent.");
202  }
203  }
204 
205  result.resize(nSamples,std::vector< std::vector<float> >(nOut));
206  for(unsigned int s = 0; s < nSamples; ++s){
207  for(unsigned int o = 0; o < nOut; ++o){
208  // Get the 2D tensor (nSamples,nNodes) for this output
209  auto output_map = outputs[o].tensor<float,2>();
210  const unsigned int nNodes = outputs[o].dim_size(1);
211  result[s][o].resize(nNodes);
212  for(unsigned int n = 0; n < nNodes; ++n){
213  result[s][o][n] = output_map(s,n);
214  }
215  }
216  }
217  }
218  else{
219  std::cout << "Processing error in Tensorflow. Returning empty output." << std::endl;
220  std::cout << "Error = " << status.ToString() << std::endl;
221  }
222 
223  return result;
224 
225 }
tensorflow::Session * fSession
Definition: CTPGraph.h:54
std::vector< std::string > fInputNames
Definition: CTPGraph.h:56
std::vector< std::string > fOutputNames
Definition: CTPGraph.h:57
int n_inputs
Definition: CTPGraph.h:29

Member Data Documentation

std::vector< std::string > tf::CTPGraph::fInputNames
private

Definition at line 56 of file CTPGraph.h.

std::vector< std::string > tf::CTPGraph::fOutputNames
private

Definition at line 57 of file CTPGraph.h.

tensorflow::Session* tf::CTPGraph::fSession
private

Definition at line 54 of file CTPGraph.h.

int tf::CTPGraph::n_inputs = 1

Definition at line 29 of file CTPGraph.h.

int tf::CTPGraph::n_outputs = 1

Definition at line 30 of file CTPGraph.h.


The documentation for this class was generated from the following files: