Public Member Functions | Private Member Functions | Static Private Member Functions | Private Attributes | List of all members
TFModel Class Reference

#include <TFModel.h>

Public Member Functions

 TFModel (const std::string &savedir, const std::vector< InputConfigKeys > &scalarInputKeys, const std::vector< InputConfigKeys > &vectorInputKeys, const std::vector< std::string > &outputKeys)
 
void ensure_initialized () const
 
std::vector< tensorflow::Tensor > predict (const VarDict &vars) const
 

Private Member Functions

void initTFSession () const
 

Static Private Member Functions

static tensorflow::Tensor constructDummyVectorInput (const std::vector< std::string > &vars, float fillValue=0.0)
 
static tensorflow::Tensor constructScalarInput (const std::unordered_map< std::string, double > &varMap, const std::vector< std::string > &vars)
 
static tensorflow::Tensor constructVectorInput (const std::unordered_map< std::string, std::vector< double >> &varMap, const std::vector< std::string > &vars)
 

Private Attributes

ModelConfig config
 
std::shared_ptr< tensorflow::Session > tfSession
 
bool initialized
 

Detailed Description

Definition at line 13 of file TFModel.h.

Constructor & Destructor Documentation

TFModel::TFModel ( const std::string &  savedir,
const std::vector< InputConfigKeys > &  scalarInputKeys,
const std::vector< InputConfigKeys > &  vectorInputKeys,
const std::vector< std::string > &  outputKeys 
)

Definition at line 17 of file TFModel.cxx.

22  : config(savedir, scalarInputKeys, vectorInputKeys, outputKeys),
23  tfSession(nullptr),
24  initialized(false)
25 { }
ModelConfig config
Definition: TFModel.h:44
bool initialized
Definition: TFModel.h:46
std::shared_ptr< tensorflow::Session > tfSession
Definition: TFModel.h:45

Member Function Documentation

Tensor TFModel::constructDummyVectorInput ( const std::vector< std::string > &  vars,
float  fillValue = 0.0 
)
static private

Definition at line 27 of file TFModel.cxx.

30 {
31  Tensor result(
32  DT_FLOAT, TensorShape( {1, 1, asInt(vars.size())} )
33  );
34 
35  auto resultData = result.tensor<float, 3>();
36 
37  for (int varIdx = 0; varIdx < asInt(vars.size()); varIdx++) {
38  resultData(0, 0, varIdx) = fillValue;
39  }
40 
41  return result;
42 }
Tensor result
Definition: TFModel.cxx:31
int asInt(T x)
Definition: TFModel.cxx:12
Tensor TFModel::constructScalarInput ( const std::unordered_map< std::string, double > &  varMap,
const std::vector< std::string > &  vars 
)
static private

Definition at line 44 of file TFModel.cxx.

48 {
49  Tensor result(
50  DT_FLOAT, TensorShape( {1, asInt(vars.size())} )
51  );
52  auto resultData = result.tensor<float, 2>();
53 
54  for (int varIdx = 0; varIdx < asInt(vars.size()); varIdx++) {
55  resultData(0, varIdx) = varMap.at(vars[varIdx]);
56  }
57 
58  return result;
59 }
Tensor result
Definition: TFModel.cxx:49
int asInt(T x)
Definition: TFModel.cxx:12
tensorflow::Tensor TFModel::constructVectorInput ( const std::unordered_map< std::string, std::vector< double >> &  varMap,
const std::vector< std::string > &  vars 
)
static private

Definition at line 61 of file TFModel.cxx.

65 {
66  if (vars.empty()) {
67  return constructDummyVectorInput(vars, 0.0);
68  }
69 
70  const size_t vectorSize = (varMap.empty()) ? 0 : varMap.at(vars[0]).size();
71 
72  if (vectorSize == 0) {
73  /*
74  * NOTE: Fake tensor with vectorSize == 1 is needed, since otherwise
75  * tensorflow fails to infer graph dimensions.
76  */
77  return constructDummyVectorInput(vars, 0.0);
78  }
79 
80  Tensor result(
81  DT_FLOAT, TensorShape({ 1, asInt(vectorSize), asInt(vars.size()) })
82  );
83  auto resultData = result.tensor<float, 3>();
84 
85  for (int varIdx = 0; varIdx < asInt(vars.size()); varIdx++)
86  {
87  const auto &values = varMap.at(vars[varIdx]);
88 
89  if (values.size() != vectorSize) {
90  throw std::runtime_error("Vectors have different lengths");
91  }
92 
93  for (int i = 0; i < asInt(vectorSize); i++) {
94  resultData(0, i, varIdx) = values[i];
95  }
96  }
97 
98  return result;
99 }
Tensor result
Definition: TFModel.cxx:80
const auto & values
Definition: TFModel.cxx:87
static tensorflow::Tensor constructDummyVectorInput(const std::vector< std::string > &vars, float fillValue=0.0)
Definition: TFModel.cxx:27
int asInt(T x)
Definition: TFModel.cxx:12
void TFModel::ensure_initialized ( ) const

Definition at line 135 of file TFModel.cxx.

136 {
137  if (initialized) {
138  return;
139  }
140 
141  config.load();
142  initTFSession();
143 
144  initialized = true;
145 }
void initTFSession() const
Definition: TFModel.cxx:101
ModelConfig config
Definition: TFModel.h:44
bool initialized
Definition: TFModel.h:46
void TFModel::initTFSession ( ) const
private

Definition at line 101 of file TFModel.cxx.

102 {
103  GraphDef graph;
104  Session *session = nullptr;
105 
106  /* TODO: Find NewSession version with safer memory management */
107  auto status = NewSession(SessionOptions(), &session);
108 
109  if (! status.ok())
110  {
111  delete session;
112  throw std::runtime_error(
113  "Failed to initialize TF Session: " + status.ToString()
114  );
115  }
116 
117  tfSession.reset(session);
118  session = nullptr;
119 
120  status = ReadBinaryProto(Env::Default(), config.getModelPath(), &graph);
121  if (! status.ok()) {
122  throw std::runtime_error(
123  "Failed to load TF Graph: " + status.ToString()
124  );
125  }
126 
127  status = tfSession->Create(graph);
128  if (! status.ok()) {
129  throw std::runtime_error(
130  "Failed to create TF Session: " + status.ToString()
131  );
132  }
133 }
GraphDef graph
Definition: TFModel.cxx:103
std::string getModelPath() const
ModelConfig config
Definition: TFModel.h:44
Session * session
Definition: TFModel.cxx:104
std::shared_ptr< tensorflow::Session > tfSession
Definition: TFModel.h:45
std::vector< Tensor > TFModel::predict ( const VarDict &  vars) const

Definition at line 147 of file TFModel.cxx.

148 {
149  ensure_initialized();
150 
151  std::vector<std::pair<std::string, Tensor>> inputs;
152  std::vector<Tensor> outputs;
153 
154  inputs.reserve(
155  config.getScalarInputs().size() + config.getVectorInputs().size()
156  );
157 
158  for (const auto &inputConfig : config.getScalarInputs()) {
159  inputs.emplace_back(
160  inputConfig.nodeName,
161  constructScalarInput(vars.scalar, inputConfig.varNames)
162  );
163  }
164 
165  for (const auto &inputConfig : config.getVectorInputs()) {
166  inputs.emplace_back(
167  inputConfig.nodeName,
168  constructVectorInput(vars.vector, inputConfig.varNames)
169  );
170  }
171 
172  auto status = tfSession->Run(
173  inputs, config.getOutputNodes(), {}, &outputs
174  );
175 
176  if (! status.ok()) {
177  throw std::runtime_error(
178  "Failed to run TF Session: " + status.ToString()
179  );
180  }
181 
182  return outputs;
183 }
static tensorflow::Tensor constructScalarInput(const std::unordered_map< std::string, double > &varMap, const std::vector< std::string > &vars)
Definition: TFModel.cxx:44
ModelConfig config
Definition: TFModel.h:44
void ensure_initialized() const
Definition: TFModel.cxx:135
std::unordered_map< std::string, std::vector< double > > vector
Definition: VarDict.h:11
const std::vector< InputConfig > & getVectorInputs() const
const std::vector< std::string > & getOutputNodes() const
static tensorflow::Tensor constructVectorInput(const std::unordered_map< std::string, std::vector< double >> &varMap, const std::vector< std::string > &vars)
Definition: TFModel.cxx:61
const std::vector< InputConfig > & getScalarInputs() const
std::shared_ptr< tensorflow::Session > tfSession
Definition: TFModel.h:45
std::unordered_map< std::string, double > scalar
Definition: VarDict.h:10

Member Data Documentation

ModelConfig TFModel::config
mutable private

Definition at line 44 of file TFModel.h.

bool TFModel::initialized
mutable private

Definition at line 46 of file TFModel.h.

std::shared_ptr<tensorflow::Session> TFModel::tfSession
mutable private

Definition at line 45 of file TFModel.h.


The documentation for this class was generated from the following files: