tf_graph.h
Go to the documentation of this file.
1 ////////////////////////////////////////////////////////////////////////////////////////////////////
2 //// Class: Graph
3 //// Authors: R.Sulej (Robert.Sulej@cern.ch), from DUNE, FNAL/NCBJ, Sept. 2017
4 /// P.Plonski, from DUNE, WUT, Sept. 2017
5 //// S. Alonso Monsalve, from DUNE, CERN, Aug. 2018
6 //// Interface to run a TensorFlow graph saved to a file. First attempts, almost functional.
7 ////
8 ////////////////////////////////////////////////////////////////////////////////////////////////////
9 
10 #ifndef Graph_h
11 #define Graph_h
12 
13 #include <memory>
14 #include <vector>
15 #include <string>
16 
// Forward declarations so this header does not pull in TensorFlow headers.
namespace tensorflow
{
  class Session;
  class Tensor;
}

namespace tf
{

/// Wrapper that loads a TensorFlow graph from a file and runs it on float inputs.
///
/// Instances are obtained only through create(); the constructor is private.
/// The session is held as a raw pointer whose lifetime is handled by the
/// out-of-line destructor (not visible in this header; presumably it closes and
/// deletes the session — confirm in the implementation file). An implicit copy
/// would duplicate that raw pointer, so copying and moving are disabled.
/// Callers hold the object through the std::unique_ptr returned by create().
class Graph
{
public:
  /// Number of graph inputs. Defaults to 1; presumably set by the constructor
  /// from the `ninputs` argument of create() (implementation not shown).
  int n_inputs = 1;
  /// Number of graph outputs. Defaults to 1; presumably set by the constructor
  /// from the `noutputs` argument of create() (implementation not shown).
  int n_outputs = 1;

  /// Load a graph from a file and prepare it for running.
  ///
  /// @param graph_file_name path of the saved graph file.
  /// @param outputs         presumably the names of the output nodes to fetch;
  ///                        empty by default (meaning defined by the implementation — confirm).
  /// @param ninputs         number of graph inputs (default 1).
  /// @param noutputs        number of graph outputs (default 1).
  /// @return the constructed Graph, or nullptr if the constructor reported failure.
  static std::unique_ptr<Graph> create(const char* graph_file_name,
                                       const std::vector<std::string> & outputs = {},
                                       int ninputs = 1, int noutputs = 1)
  {
    // Start pessimistic. If a constructor path returns without writing the flag,
    // this is treated as a failure instead of reading an uninitialized bool (UB).
    bool success = false;
    std::unique_ptr<Graph> ptr(new Graph(graph_file_name, outputs, success, ninputs, noutputs));
    if (!success) { return nullptr; } // the failed Graph is destroyed when ptr goes out of scope
    return ptr;
  }

  ~Graph();

  // Non-copyable and non-movable: the raw fSession pointer must have exactly one owner.
  Graph(const Graph &) = delete;
  Graph & operator=(const Graph &) = delete;
  Graph(Graph &&) = delete;
  Graph & operator=(Graph &&) = delete;

  /// Run the graph on a single 2D input and return a flat vector of output values.
  /// NOTE(review): the exact input/output layout is defined by the implementation — confirm.
  std::vector<float> run(const std::vector< std::vector<float> > & x);

  /// Process a vector of 3D inputs and return the 1D outputs for each processed sample.
  /// All inputs are used if samples == -1; otherwise only the first `samples` inputs.
  /// Presumably indexed as result[sample][output][value] — confirm against the implementation.
  std::vector< std::vector< std::vector<float> > > run(
    const std::vector< std::vector< std::vector< std::vector<float> > > > & x,
    long long int samples = -1);

  /// Run the graph on already-built input tensors; returns the same nested
  /// vector type as the 4D-input overload.
  std::vector< std::vector< std::vector<float> > > run(const std::vector< tensorflow::Tensor > & x);

private:
  /// Not-throwing constructor; reports the outcome through `success`.
  Graph(const char* graph_file_name, const std::vector<std::string> & outputs, bool & success, int ninputs, int noutputs);

  /// Session handle. Initialized to nullptr so it is never an indeterminate
  /// pointer, even if the constructor fails before creating a session.
  tensorflow::Session* fSession = nullptr;
  std::vector< std::string > fInputNames;
  std::vector< std::string > fOutputNames;
};

} // namespace tf
62 
63 #endif
boost::adjacency_list< boost::vecS, boost::vecS, boost::bidirectionalS, vertex_property, edge_property, graph_property > Graph
Definition: ModuleGraph.h:22
struct vector vector
Definition: tf_graph.h:23
tensorflow::Session * fSession
Definition: tf_graph.h:55
static std::unique_ptr< Graph > create(const char *graph_file_name, const std::vector< std::string > &outputs={}, int ninputs=1, int noutputs=1)
Definition: tf_graph.h:32
std::vector< std::string > fOutputNames
Definition: tf_graph.h:58
std::vector< std::string > fInputNames
Definition: tf_graph.h:57
list x
Definition: train.py:276