tf_graph.cc
Go to the documentation of this file.
1 ////////////////////////////////////////////////////////////////////////////////////////////////////
2 // Class: Graph
3 // Authors: R.Sulej (Robert.Sulej@cern.ch), from DUNE, FNAL/NCBJ, Sept. 2017
4 // P.Plonski, from DUNE, WUT, Sept. 2017
5 // S.Alonso-Monsalve, from DUNE, CERN, Aug. 2018
6 // Interface to run a TensorFlow graph saved to a file. First attempts, quite functional.
7 //
8 ////////////////////////////////////////////////////////////////////////////////////////////////////
9 
10 #include "tf_graph.h"
11 
13 #include "tensorflow/core/platform/env.h"
14 
15 #include "tensorflow/core/public/session_options.h"
16 
17 // -------------------------------------------------------------------
18 tf::Graph::Graph(const char* graph_file_name, const std::vector<std::string> & outputs, bool & success, int ninputs, int noutputs)
19 {
20  success = false; // until all is done correctly
21 
22  n_inputs = ninputs;
23  n_outputs = noutputs;
24 
25  // Force tf to only use a single core so it doesn't eat batch farms
26  tensorflow::SessionOptions options;
27  tensorflow::ConfigProto &config = options.config;
28  config.set_inter_op_parallelism_threads(1);
29  config.set_intra_op_parallelism_threads(1);
30  config.set_use_per_session_threads(false);
31 
32  auto status = tensorflow::NewSession(options, &fSession);
33  if (!status.ok())
34  {
35  std::cout << status.ToString() << std::endl;
36  return;
37  }
38 
39  tensorflow::GraphDef graph_def;
40  status = tensorflow::ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
41  if (!status.ok())
42  {
43  std::cout << status.ToString() << std::endl;
44  return;
45  }
46 
47  size_t ng = graph_def.node().size();
48 
49  // fill input names (TODO: generic)
50  for(int i=0; i<n_inputs; ++i)
51  {
52  fInputNames.push_back(graph_def.node()[i].name());
53  }
54 
55  /*
56  std::cout << "Input names: " << std::endl;
57  for(int i=0; i<n_inputs; ++i)
58  std::cout << fInputNames[i] << std::endl;
59  */
60 
61  // last node as output if no specific name provided
62  if (outputs.empty())
63  {
64  for(int i=n_outputs; i>0; --i)
65  {
66  fOutputNames.push_back(graph_def.node()[ng - i].name());
67  }
68 
69  /*
70  std::cout << "Output names: " << std::endl;
71  for(int i=0; i<n_outputs; ++i)
72  std::cout << fOutputNames[i] << std::endl;
73  */
74  }
75  else // or last nodes with names containing provided strings
76  {
78  for (size_t n = 0; n < ng; ++n)
79  {
80  name = graph_def.node()[n].name();
81  auto pos = name.find("/");
82  if (pos != std::string::npos) { basename = name.substr(0, pos); }
83  else { continue; }
84 
85  bool found = false;
86  for (const auto & s : outputs)
87  {
88  if (name.find(s) != std::string::npos) { found = true; break; }
89  }
90  if (found)
91  {
92  if (!last.empty() && (basename != current))
93  {
94  fOutputNames.push_back(last);
95  }
96  current = basename;
97  last = name;
98  }
99  }
100  if (!last.empty()) { fOutputNames.push_back(last); }
101  }
102  if (fOutputNames.empty())
103  {
104  std::cout << "Output nodes not found in the graph." << std::endl;
105  return;
106  }
107 
108  status = fSession->Create(graph_def);
109  if (!status.ok())
110  {
111  std::cout << status.ToString() << std::endl;
112  return;
113  }
114 
115  success = true; // ok, graph loaded from the file
116 }
117 
119 {
120  if ( ! fSession->Close().ok() ) {
121  std::cout << "tf::Graph::dtor: " << "Close failed." << std::endl;
122  }
123  //fSession->Close();
124  delete fSession;
125 }
126 
127 // -------------------------------------------------------------------
128 
129 std::vector< std::vector< std::vector<float> > > tf::Graph::run(
130  const std::vector< std::vector< std::vector< std::vector<float> > > > & x,
131  long long int samples)
132 {
133  if ((samples == 0) || x.empty() || x.front().empty() || x.front().front().empty() || x.front().front().front().empty())
134  return std::vector< std::vector< std::vector<float> > >();
135 
136  if ((samples == -1) || (samples > (long long int)x.size())) { samples = x.size(); }
137 
138  long long int
139  rows = x.front().size(),
140  cols = x.front().front().size(),
141  depth = x.front().front().front().size();
142 
143  std::vector< tensorflow::Tensor > _x;
144 
145  // Single-input network
146  if (n_inputs == 1)
147  {
148  _x.push_back(tensorflow::Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({ samples, rows, cols, depth })));
149  auto input_map = _x[0].tensor<float, 4>();
150  for (long long int s = 0; s < samples; ++s) {
151  const auto & sample = x[s];
152  for (long long int r = 0; r < rows; ++r) {
153  const auto & row = sample[r];
154  for (long long int c = 0; c < cols; ++c) {
155  const auto & col = row[c];
156  for (long long int d = 0; d < depth; ++d) {
157  input_map(s, r, c, d) = col[d];
158  }
159  }
160  }
161  }
162  }
163  // Multi-input network
164  else
165  {
166  for(int i=0; i<depth; ++i){
167  _x.push_back(tensorflow::Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({ samples, rows, cols, 1 })));
168  }
169 
170  //tensorflow::Tensor _x(tensorflow::DT_FLOAT, tensorflow::TensorShape({ samples, rows, cols, depth }));
171 
172  for(int view=0; view<depth; ++view){
173  auto input_map = _x[view].tensor<float, 4>();
174  for (long long int s = 0; s < samples; ++s) {
175  const auto & sample = x[s];
176  for (long long int r = 0; r < rows; ++r) {
177  const auto & row = sample[r];
178  for (long long int c = 0; c < cols; ++c) {
179  const auto & col = row[c];
180  long long int d = view;
181  input_map(s, r, c, 0) = col[d];
182  }
183  }
184  }
185  }
186  }
187 
188  return run(_x);
189 }
190 
191 // -------------------------------------------------------------------
192 
193 std::vector< std::vector< std::vector< float > > > tf::Graph::run(const std::vector< tensorflow::Tensor > & x)
194 {
195  std::vector< std::pair<std::string, tensorflow::Tensor> > inputs;
196  for(int i=0; i<n_inputs; ++i){
197  inputs.push_back({fInputNames[i], x[i]});
198  }
199 
200  /*
201  // print input/outputs
202  for(int i = 0; i<n_inputs; ++i)
203  std::cout << inputs[i].first << std::endl;
204  for(int i = 0; i<n_outputs; ++i)
205  std::cout << fOutputNames[i] << std::endl;
206  */
207  //std::cout << "run session" << std::endl;
208 
209  std::vector<tensorflow::Tensor> outputs;
210  auto status = fSession->Run(inputs, fOutputNames, {}, &outputs);
211 
212  //std::cout << "out size " << outputs.size() << std::endl;
213 
214  if (status.ok())
215  {
216  size_t samples = 0;
217 
218  for (size_t o = 0; o < outputs.size(); ++o)
219  {
220  if (o == 0) { samples = outputs[o].dim_size(0); }
221  else if ((int)samples != outputs[o].dim_size(0))
222  {
223  throw std::string("TF outputs size inconsistent.");
224  }
225  }
226 
227  std::vector< std::vector< std::vector< float > > > result;
228  result.resize(samples, std::vector< std::vector< float > >(outputs.size()));
229 
230  for (size_t s = 0; s < samples; ++s)
231  {
232  for (size_t o = 0; o < outputs.size(); ++o)
233  {
234  size_t n = outputs[o].dim_size(1);
235  auto output_map = outputs[o].tensor<float, 2>();
236 
237  result[s][o].resize(outputs[o].dim_size(1));
238 
239  std::vector< float > & vs = result[s][o];
240  for (size_t i = 0; i < n; ++i)
241  {
242  vs[i] = output_map(s, i);
243  }
244  }
245  }
246 
247  return result;
248  }
249  else
250  {
251  std::cout << status.ToString() << std::endl;
252  return std::vector< std::vector< std::vector< float > > >();
253  }
254 }
255 // -------------------------------------------------------------------
256 
static QCString name
Definition: declinfo.cpp:673
static QCString result
std::string string
Definition: nybbler.cc:12
struct vector vector
std::vector< float > run(const std::vector< std::vector< float > > &x)
Definition: tf_graph.cc:99
tensorflow::Session * fSession
Definition: tf_graph.h:55
int n_outputs
Definition: tf_graph.h:30
static Config * config
Definition: config.cpp:1054
std::void_t< T > n
static Entry * current
std::vector< std::string > fOutputNames
Definition: tf_graph.h:58
int n_inputs
Definition: tf_graph.h:29
std::vector< std::string > fInputNames
Definition: tf_graph.h:57
Graph(const char *graph_file_name, const std::vector< std::string > &outputs, bool &success, int ninputs, int noutputs)
Not-throwing constructor.
Definition: tf_graph.cc:18
list x
Definition: train.py:276
static QCString * s
Definition: config.cpp:1042
QTextStream & endl(QTextStream &s)