TFModel.cxx
Go to the documentation of this file.
1 #include "TFModel.h"
2 
3 #include <utility>
4 
5 #include <boost/numeric/conversion/cast.hpp>
7 #include "tensorflow/core/platform/env.h"
8 
9 using namespace tensorflow;
10 
11 template<typename T>
12 inline int asInt(T x)
13 {
14  return boost::numeric_cast<int>(x);
15 }
16 
18  const std::string &savedir,
19  const std::vector<InputConfigKeys> &scalarInputKeys,
20  const std::vector<InputConfigKeys> &vectorInputKeys,
21  const std::vector<std::string> &outputKeys
22 ) : config(savedir, scalarInputKeys, vectorInputKeys, outputKeys),
23  tfSession(nullptr),
24  initialized(false)
25 { }
26 
28  const std::vector<std::string> &vars, float fillValue
29 )
30 {
31  Tensor result(
32  DT_FLOAT, TensorShape( {1, 1, asInt(vars.size())} )
33  );
34 
35  auto resultData = result.tensor<float, 3>();
36 
37  for (int varIdx = 0; varIdx < asInt(vars.size()); varIdx++) {
38  resultData(0, 0, varIdx) = fillValue;
39  }
40 
41  return result;
42 }
43 
45  const std::unordered_map<std::string, double> &varMap,
46  const std::vector<std::string> &vars
47 )
48 {
49  Tensor result(
50  DT_FLOAT, TensorShape( {1, asInt(vars.size())} )
51  );
52  auto resultData = result.tensor<float, 2>();
53 
54  for (int varIdx = 0; varIdx < asInt(vars.size()); varIdx++) {
55  resultData(0, varIdx) = varMap.at(vars[varIdx]);
56  }
57 
58  return result;
59 }
60 
61 tensorflow::Tensor TFModel::constructVectorInput(
62  const std::unordered_map<std::string, std::vector<double>> &varMap,
63  const std::vector<std::string> &vars
64 )
65 {
66  if (vars.empty()) {
67  return constructDummyVectorInput(vars, 0.0);
68  }
69 
70  const size_t vectorSize = (varMap.empty()) ? 0 : varMap.at(vars[0]).size();
71 
72  if (vectorSize == 0) {
73  /*
74  * NOTE: Fake tensor with vectorSize == 1 is needed, since otherwise
75  * tensorflow fails to infer graph dimensions.
76  */
77  return constructDummyVectorInput(vars, 0.0);
78  }
79 
80  Tensor result(
81  DT_FLOAT, TensorShape({ 1, asInt(vectorSize), asInt(vars.size()) })
82  );
83  auto resultData = result.tensor<float, 3>();
84 
85  for (int varIdx = 0; varIdx < asInt(vars.size()); varIdx++)
86  {
87  const auto &values = varMap.at(vars[varIdx]);
88 
89  if (values.size() != vectorSize) {
90  throw std::runtime_error("Vectors have different lengths");
91  }
92 
93  for (int i = 0; i < asInt(vectorSize); i++) {
94  resultData(0, i, varIdx) = values[i];
95  }
96  }
97 
98  return result;
99 }
100 
102 {
103  GraphDef graph;
104  Session *session = nullptr;
105 
106  /* TODO: Find NewSession version with safer memory management */
107  auto status = NewSession(SessionOptions(), &session);
108 
109  if (! status.ok())
110  {
111  delete session;
112  throw std::runtime_error(
113  "Failed to initialize TF Session: " + status.ToString()
114  );
115  }
116 
117  tfSession.reset(session);
118  session = nullptr;
119 
120  status = ReadBinaryProto(Env::Default(), config.getModelPath(), &graph);
121  if (! status.ok()) {
122  throw std::runtime_error(
123  "Failed to load TF Graph: " + status.ToString()
124  );
125  }
126 
127  status = tfSession->Create(graph);
128  if (! status.ok()) {
129  throw std::runtime_error(
130  "Failed to create TF Session: " + status.ToString()
131  );
132  }
133 }
134 
136 {
137  if (initialized) {
138  return;
139  }
140 
141  config.load();
142  initTFSession();
143 
144  initialized = true;
145 }
146 
147 std::vector<Tensor> TFModel::predict(const VarDict &vars) const
148 {
150 
151  std::vector<std::pair<std::string, Tensor>> inputs;
152  std::vector<Tensor> outputs;
153 
154  inputs.reserve(
155  config.getScalarInputs().size() + config.getVectorInputs().size()
156  );
157 
158  for (const auto &inputConfig : config.getScalarInputs()) {
159  inputs.emplace_back(
160  inputConfig.nodeName,
161  constructScalarInput(vars.scalar, inputConfig.varNames)
162  );
163  }
164 
165  for (const auto &inputConfig : config.getVectorInputs()) {
166  inputs.emplace_back(
167  inputConfig.nodeName,
168  constructVectorInput(vars.vector, inputConfig.varNames)
169  );
170  }
171 
172  auto status = tfSession->Run(
173  inputs, config.getOutputNodes(), {}, &outputs
174  );
175 
176  if (! status.ok()) {
177  throw std::runtime_error(
178  "Failed to run TF Session: " + status.ToString()
179  );
180  }
181 
182  return outputs;
183 }
184 
static QCString result
static tensorflow::Tensor constructScalarInput(const std::unordered_map< std::string, double > &varMap, const std::vector< std::string > &vars)
Definition: TFModel.cxx:44
std::string string
Definition: nybbler.cc:12
def graph(desc, maker=maker)
Definition: apa.py:294
std::vector< tensorflow::Tensor > predict(const VarDict &vars) const
Definition: TFModel.cxx:147
void initTFSession() const
Definition: TFModel.cxx:101
std::string getModelPath() const
ModelConfig config
Definition: TFModel.h:44
Definition: VarDict.h:8
void ensure_initialized() const
Definition: TFModel.cxx:135
std::unordered_map< std::string, std::vector< double > > vector
Definition: VarDict.h:11
static Config * config
Definition: config.cpp:1054
TFModel(const std::string &savedir, const std::vector< InputConfigKeys > &scalarInputKeys, const std::vector< InputConfigKeys > &vectorInputKeys, const std::vector< std::string > &outputKeys)
Definition: TFModel.cxx:17
Q_UINT16 values[128]
def session(dburl="sqlite:///:memory:")
Definition: db.py:754
const std::vector< InputConfig > & getVectorInputs() const
static tensorflow::Tensor constructDummyVectorInput(const std::vector< std::string > &vars, float fillValue=0.0)
Definition: TFModel.cxx:27
const std::vector< std::string > & getOutputNodes() const
static tensorflow::Tensor constructVectorInput(const std::unordered_map< std::string, std::vector< double >> &varMap, const std::vector< std::string > &vars)
Definition: TFModel.cxx:61
list x
Definition: train.py:276
bool initialized
Definition: TFModel.h:46
const std::vector< InputConfig > & getScalarInputs() const
int asInt(T x)
Definition: TFModel.cxx:12
static QMap< QCString, MemberDef * > varMap
Definition: vhdldocgen.cpp:713
std::shared_ptr< tensorflow::Session > tfSession
Definition: TFModel.h:45
std::unordered_map< std::string, double > scalar
Definition: VarDict.h:10