tf_graph.cc
Go to the documentation of this file.
1 ////////////////////////////////////////////////////////////////////////////////////////////////////
2 // Class: Graph
3 // Authors: R.Sulej (Robert.Sulej@cern.ch), from DUNE, FNAL/NCBJ, Sept. 2017
4 // P.Plonski, from DUNE, WUT, Sept. 2017
5 //
6 // Interface to run a TensorFlow graph saved to a file. First attempts, quite functional.
7 //
8 ////////////////////////////////////////////////////////////////////////////////////////////////////
9 
10 #include "tf_graph.h"
11 
12 #include "tensorflow/core/public/session.h"
13 #include "tensorflow/core/platform/env.h"
14 
15 #include "tensorflow/core/public/session_options.h"
16 
17 // -------------------------------------------------------------------
18 tf::Graph::Graph(const char* graph_file_name, const std::vector<std::string> & outputs, bool & success)
19 {
20  success = false; // until all is done correctly
21 
22  // Force tf to only use a single core so it doesn't eat batch farms
23  tensorflow::SessionOptions options;
24  tensorflow::ConfigProto &config = options.config;
25  config.set_inter_op_parallelism_threads(1);
26  config.set_intra_op_parallelism_threads(1);
27  config.set_use_per_session_threads(false);
28 
29  auto status = tensorflow::NewSession(options, &fSession);
30  if (!status.ok())
31  {
32  std::cout << status.ToString() << std::endl;
33  return;
34  }
35 
36  tensorflow::GraphDef graph_def;
37  status = tensorflow::ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
38  if (!status.ok())
39  {
40  std::cout << status.ToString() << std::endl;
41  return;
42  }
43 
44  size_t ng = graph_def.node().size();
45  fInputName = graph_def.node()[0].name();
46 
47  // last node as output if no specific name provided
48  if (outputs.empty()) { fOutputNames.push_back(graph_def.node()[ng - 1].name()); }
49  else // or last nodes with names containing provided strings
50  {
52  for (size_t n = 0; n < ng; ++n)
53  {
54  name = graph_def.node()[n].name();
55  auto pos = name.find("/");
56  if (pos != std::string::npos) { basename = name.substr(0, pos); }
57  else { continue; }
58 
59  bool found = false;
60  for (const auto & s : outputs)
61  {
62  if (name.find(s) != std::string::npos) { found = true; break; }
63  }
64  if (found)
65  {
66  if (!last.empty() && (basename != current))
67  {
68  fOutputNames.push_back(last);
69  }
70  current = basename;
71  last = name;
72  }
73  }
74  if (!last.empty()) { fOutputNames.push_back(last); }
75  }
76  if (fOutputNames.empty())
77  {
78  std::cout << "Output nodes not found in the graph." << std::endl;
79  return;
80  }
81 
82  status = fSession->Create(graph_def);
83  if (!status.ok())
84  {
85  std::cout << status.ToString() << std::endl;
86  return;
87  }
88 
89  success = true; // ok, graph loaded from the file
90 }
91 
93 {
94  fSession->Close();
95  delete fSession;
96 }
97 // -------------------------------------------------------------------
98 
99 std::vector<float> tf::Graph::run(const std::vector< std::vector<float> > & x)
100 {
101  if (x.empty() || x.front().empty()) { return std::vector<float>(); }
102 
103  long long int rows = x.size(), cols = x.front().size();
104 
105  tensorflow::Tensor _x(tensorflow::DT_FLOAT, tensorflow::TensorShape({ 1, rows, cols, 1 }));
106  auto input_map = _x.tensor<float, 4>();
107 
108  for (long long int r = 0; r < rows; ++r) {
109  const auto & row = x[r];
110  for (long long int c = 0; c < cols; ++c) {
111  input_map(0, r, c, 0) = row[c];
112  }
113  }
114 
115  auto result = run(_x);
116  if (!result.empty()) { return result.front(); }
117  else { return std::vector<float>(); }
118 }
119 // -------------------------------------------------------------------
120 
121 std::vector< std::vector<float> > tf::Graph::run(
122  const std::vector< std::vector< std::vector< std::vector<float> > > > & x,
123  long long int samples)
124 {
125  if ((samples == 0) || x.empty() || x.front().empty() || x.front().front().empty() || x.front().front().front().empty())
126  return std::vector< std::vector<float> >();
127 
128  if ((samples == -1) || (samples > (long long int)x.size())) { samples = x.size(); }
129 
130  long long int
131  rows = x.front().size(),
132  cols = x.front().front().size(),
133  depth = x.front().front().front().size();
134 
135  tensorflow::Tensor _x(tensorflow::DT_FLOAT, tensorflow::TensorShape({ samples, rows, cols, depth }));
136  auto input_map = _x.tensor<float, 4>();
137  for (long long int s = 0; s < samples; ++s) {
138  const auto & sample = x[s];
139  for (long long int r = 0; r < rows; ++r) {
140  const auto & row = sample[r];
141  for (long long int c = 0; c < cols; ++c) {
142  const auto & col = row[c];
143  for (long long int d = 0; d < depth; ++d) {
144  input_map(s, r, c, d) = col[d];
145  }
146  }
147  }
148  }
149 
150  return run(_x);
151 }
152 // -------------------------------------------------------------------
153 
154 std::vector< std::vector< float > > tf::Graph::run(const tensorflow::Tensor & x)
155 {
156  std::vector< std::pair<std::string, tensorflow::Tensor> > inputs = {
157  { fInputName, x }
158  };
159 
160  //std::cout << "run session" << std::endl;
161 
162  std::vector<tensorflow::Tensor> outputs;
163  auto status = fSession->Run(inputs, fOutputNames, {}, &outputs);
164 
165  //std::cout << "out size " << outputs.size() << std::endl;
166 
167  if (status.ok())
168  {
169  size_t samples = 0, nouts = 0;
170  for (size_t o = 0; o < outputs.size(); ++o)
171  {
172  if (o == 0) { samples = outputs[o].dim_size(0); }
173  else if ((int)samples != outputs[o].dim_size(0))
174  {
175  throw std::string("TF outputs size inconsistent.");
176  }
177  nouts += outputs[o].dim_size(1);
178  }
179  //std::cout << "samples " << samples << " nouts " << nouts << std::endl;
180 
181  std::vector< std::vector< float > > result;
182  result.resize(samples, std::vector< float >(nouts));
183 
184  size_t idx0 = 0;
185  for (size_t o = 0; o < outputs.size(); ++o)
186  {
187  auto output_map = outputs[o].tensor<float, 2>();
188 
189  size_t n = outputs[o].dim_size(1);
190  for (size_t s = 0; s < samples; ++s) {
191  std::vector< float > & vs = result[s];
192  for (size_t i = 0; i < n; ++i) {
193  vs[idx0 + i] = output_map(s, i);
194  }
195  }
196  idx0 += n;
197  }
198  return result;
199  }
200  else
201  {
202  std::cout << status.ToString() << std::endl;
203  return std::vector< std::vector< float > >();
204  }
205 }
206 // -------------------------------------------------------------------
207 
static QCString name
Definition: declinfo.cpp:673
static QCString result
std::string string
Definition: nybbler.cc:12
std::string fInputName
Definition: tf_graph.h:53
struct vector vector
std::vector< float > run(const std::vector< std::vector< float > > &x)
Definition: tf_graph.cc:99
tensorflow::Session * fSession
Definition: tf_graph.h:55
static Config * config
Definition: config.cpp:1054
std::void_t< T > n
static Entry * current
std::vector< std::string > fOutputNames
Definition: tf_graph.h:58
Graph(const char *graph_file_name, const std::vector< std::string > &outputs, bool &success, int ninputs, int noutputs)
Not-throwing constructor.
Definition: tf_graph.cc:18
list x
Definition: train.py:276
static QCString * s
Definition: config.cpp:1042
QTextStream & endl(QTextStream &s)