CTPGraph.h
Go to the documentation of this file.
1 ////////////////////////////////////////////////////////////////////////////////////////////////////
2 //// Class: CTPGraph
3 //// Authors: R.Sulej (Robert.Sulej@cern.ch), from DUNE, FNAL/NCBJ, Sept. 2017
4 /// P.Plonski, from DUNE, WUT, Sept. 2017
5 //// S. Alonso Monsalve, from DUNE, CERN, Aug. 2018
6 //// Interface to run a TensorFlow graph saved to a file. First attempts, almost functional.
7 ////
8 ////////////////////////////////////////////////////////////////////////////////////////////////////
9 
10 #ifndef CTPGraph_h
11 #define CTPGraph_h
12 
13 #include <memory>
14 #include <vector>
15 #include <string>
16 
// Forward declarations only: keeps TensorFlow headers out of this public header.
namespace tensorflow
{
    class Session;
    class Tensor;
}

namespace tf
{

/// Wrapper that runs a TensorFlow graph saved to a file.
///
/// Instances can only be obtained through create(); the constructor is private
/// and reports failure through an out-parameter instead of throwing.
///
/// The object holds a raw tensorflow::Session pointer and has a user-declared
/// destructor (defined out of line). Copying is therefore disabled: a
/// member-wise copy would leave two objects holding the same session pointer.
class CTPGraph
{
public:
    int n_inputs = 1;   ///< Number of graph inputs (default 1; presumably set from create()'s ninputs by the constructor — TODO confirm).
    int n_outputs = 1;  ///< Number of graph outputs (default 1; presumably set from create()'s noutputs by the constructor — TODO confirm).

    /// Factory: builds a CTPGraph from a graph file.
    /// @param graph_file_name Path of the graph file to load.
    /// @param outputs         Names of output nodes; empty by default (meaning of the
    ///                        empty case is decided by the constructor — TODO confirm).
    /// @param ninputs         Number of inputs (default 1).
    /// @param noutputs        Number of outputs (default 1).
    /// @return Owning pointer to the graph, or nullptr if the constructor reported failure.
    static std::unique_ptr<CTPGraph> create(const char* graph_file_name, const std::vector<std::string> & outputs = {}, int ninputs = 1, int noutputs = 1)
    {
        // Start from "failed" so a constructor path that never assigns the flag
        // yields nullptr instead of reading an indeterminate value (UB).
        bool success = false;
        // std::make_unique cannot reach the private constructor, hence the explicit new.
        std::unique_ptr<CTPGraph> ptr(new CTPGraph(graph_file_name, outputs, success, ninputs, noutputs));
        if (!success) { return nullptr; }
        return ptr;
    }

    ~CTPGraph();

    // Non-copyable: the destructor manages fSession, so a shallow copy would share it.
    CTPGraph(const CTPGraph &) = delete;
    CTPGraph & operator=(const CTPGraph &) = delete;

    /// Run the graph on a single input given as a 2D float array.
    /// NOTE(review): expected shape/ordering of x is defined by the implementation — TODO confirm.
    std::vector<float> run(const std::vector< std::vector<float> > & x);

    /// Run the graph on a vector of 3D inputs; all inputs in x are processed.
    /// Returns a three-level nested vector, presumably indexed
    /// [sample][output][value] — TODO confirm against the implementation.
    std::vector< std::vector < std::vector< float > > > run(
        const std::vector< std::vector< std::vector<float> > > & x);

    /// Run the graph on already-built TensorFlow tensors.
    /// Return layout presumably matches the 3D overload above — TODO confirm.
    std::vector< std::vector < std::vector< float > > > run(const std::vector< tensorflow::Tensor > & x);

private:
    /// Not-throwing constructor; reports the outcome through `success`.
    CTPGraph(const char* graph_file_name, const std::vector<std::string> & outputs, bool & success, int ninputs, int noutputs);

    tensorflow::Session* fSession = nullptr;  ///< Raw session pointer; handled by the out-of-line destructor (TODO confirm it is closed/deleted there).
    std::vector< std::string > fInputNames;   ///< Names of the graph's input nodes.
    std::vector< std::string > fOutputNames;  ///< Names of the graph's output nodes.
};

} // namespace tf
61 
62 #endif
static std::unique_ptr< CTPGraph > create(const char *graph_file_name, const std::vector< std::string > &outputs={}, int ninputs=1, int noutputs=1)
Definition: CTPGraph.h:32
struct vector vector
Definition: tf_graph.h:23
tensorflow::Session * fSession
Definition: CTPGraph.h:54
std::vector< std::string > fInputNames
Definition: CTPGraph.h:56
std::vector< std::string > fOutputNames
Definition: CTPGraph.h:57
list x
Definition: train.py:276