tf_graph.h
Go to the documentation of this file.
1 ////////////////////////////////////////////////////////////////////////////////////////////////////
2 //// Class: Graph
3 //// Authors: R.Sulej (Robert.Sulej@cern.ch), from DUNE, FNAL/NCBJ, Sept. 2017
4 /// P.Plonski, from DUNE, WUT, Sept. 2017
5 ////
6 //// Interface to run a TensorFlow graph saved to a file. First attempts, almost functional.
7 ////
8 ////////////////////////////////////////////////////////////////////////////////////////////////////
9 
10 #ifndef Graph_h
11 #define Graph_h
12 
13 #include <memory>
14 #include <vector>
15 #include <string>
16 
namespace tensorflow
{
    class Session;
    class Tensor;
}

namespace tf
{

/// Wrapper that loads a TensorFlow graph from a file and evaluates it.
///
/// Instances are obtained only through create(): the constructor is private
/// and reports failure through an out-parameter instead of throwing.
/// The object holds a raw tensorflow::Session pointer and declares a
/// destructor, so copying is disabled -- an implicit member-wise copy would
/// leave two objects holding the same session pointer (presumably released
/// in ~Graph(), which is defined out of line; confirm in the implementation).
class Graph
{
public:
    /// Load the graph stored in @p graph_file_name.
    /// @param graph_file_name path to the saved graph file
    /// @param outputs names of the output nodes to fetch; defaults to an empty
    ///        list (how an empty list is handled is decided by the constructor)
    /// @return owning pointer to the loaded graph, or nullptr if the
    ///         constructor reported failure
    static std::unique_ptr<Graph> create(const char* graph_file_name, const std::vector<std::string> & outputs = {})
    {
        // Start from "failed" so a constructor path that never assigns the flag
        // yields nullptr instead of reading an indeterminate bool (UB).
        bool success = false;
        // std::make_unique cannot reach the private constructor, hence new.
        std::unique_ptr<Graph> ptr(new Graph(graph_file_name, outputs, success));
        if (!success) { return nullptr; }
        return ptr;
    }

    ~Graph();

    // Non-copyable: copies would share the raw session pointer (see class comment).
    Graph(const Graph &) = delete;
    Graph & operator=(const Graph &) = delete;

    /// Evaluate the graph on an input given as a 2D array of floats.
    /// NOTE(review): assumed to be a single 2D sample producing a flat 1D
    /// result -- confirm against the implementation.
    std::vector<float> run(const std::vector< std::vector<float> > & x);

    // process vector of 3D inputs, return vector of 1D outputs; use all inputs
    // if samples = -1, or only the specified number of first samples
    std::vector< std::vector<float> > run(
        const std::vector< std::vector< std::vector< std::vector<float> > > > & x,
        long long int samples = -1);

    /// Evaluate the graph on an already-built TensorFlow tensor.
    /// NOTE(review): presumably returns one output vector per sample in @p x -- confirm.
    std::vector< std::vector< float > > run(const tensorflow::Tensor & x);

private:
    /// Not-throwing constructor; reports the outcome through @p success,
    /// which create() checks before handing out the object.
    Graph(const char* graph_file_name, const std::vector<std::string> & outputs, bool & success);

    tensorflow::Session* fSession = nullptr; ///< TensorFlow session handle (nullptr until the constructor sets it)
    std::string fInputName;                  ///< input node name (presumably filled in by the constructor)
    std::vector< std::string > fOutputNames; ///< output node names (presumably derived from the `outputs` argument)
};

} // namespace tf
58 
59 #endif
boost::adjacency_list< boost::vecS, boost::vecS, boost::bidirectionalS, vertex_property, edge_property, graph_property > Graph
Definition: ModuleGraph.h:22
std::string string
Definition: nybbler.cc:12
std::string fInputName
Definition: tf_graph.h:53
struct vector vector
Definition: tf_graph.h:23
static std::unique_ptr< Graph > create(const char *graph_file_name, const std::vector< std::string > &outputs={})
Definition: tf_graph.h:29
list x
Definition: train.py:276