TFModel.h
Go to the documentation of this file.
1 #pragma once
2 
3 #include <memory>
4 #include <string>
5 #include <utility>
6 #include <vector>
7 
9 #include "ModelConfig.h"
10 
11 namespace tensorflow { class Session; class Tensor; }
12 
// TFModel: declaration of a model wrapper built from a save directory string
// plus lists of scalar-input keys, vector-input keys and output keys. Its
// public surface is the constructor, ensure_initialized() and predict(),
// which returns tensorflow::Tensor values for a VarDict.
// Only declarations are visible in this header; the behaviour notes below that
// are marked "presumably" are inferred from names and must be confirmed
// against the implementation file.
13 class TFModel
14 {
15 public:
// Constructor.
//   savedir         - string naming where the model was saved (presumably a
//                     TensorFlow saved-model directory - TODO confirm).
//   scalarInputKeys - InputConfigKeys entries for the scalar inputs.
//   vectorInputKeys - InputConfigKeys entries for the vector inputs.
//   outputKeys      - names of the outputs (presumably the tensors fetched by
//                     predict() - TODO confirm).
16  TFModel(
17  const std::string &savedir,
18  const std::vector<InputConfigKeys> &scalarInputKeys,
19  const std::vector<InputConfigKeys> &vectorInputKeys,
20  const std::vector<std::string> &outputKeys
21  );
22 
// Const member with no parameters and no return value. Presumably sets up the
// session on first use (see the mutable members below) - TODO confirm.
23  void ensure_initialized() const;
// Evaluates the model for the variables in `vars` and returns the resulting
// tensors by value. Presumably one tensor per entry of outputKeys - TODO confirm.
24  std::vector<tensorflow::Tensor> predict(const VarDict &vars) const;
25 
26 private:
// Static helper returning a tensor for the variable names in `vars`;
// `fillValue` defaults to 0.0. Presumably every element is set to fillValue
// as a placeholder input - TODO confirm.
27  static tensorflow::Tensor constructDummyVectorInput(
28  const std::vector<std::string> &vars, float fillValue = 0.0
29  );
30 
// Static helper returning a tensor built from `varMap` (name -> double) for
// the variable names listed in `vars`. Handling of names absent from varMap
// is not visible here.
// NOTE(review): <unordered_map> is not among the includes visible in this
// listing; presumably it arrives transitively via a project header - confirm.
31  static tensorflow::Tensor constructScalarInput(
32  const std::unordered_map<std::string, double> &varMap,
33  const std::vector<std::string> &vars
34  );
35 
// Static helper returning a tensor built from `varMap` (name -> vector of
// doubles) for the variable names listed in `vars`. Tensor shape and any
// padding of unequal-length vectors are not visible here.
36  static tensorflow::Tensor constructVectorInput(
37  const std::unordered_map<std::string, std::vector<double>> &varMap,
38  const std::vector<std::string> &vars
39  );
40 
// Const member; presumably creates the TensorFlow session stored in
// tfSession - TODO confirm in the implementation.
41  void initTFSession() const;
42 
43 private:
// NOTE(review): the listing numbering jumps from 43 to 45, and the page index
// attributes `ModelConfig config` to TFModel.h:44 - that member declaration
// appears to be missing from this extract.
// Shared-ownership handle to the TensorFlow session. Declared mutable, so it
// can be assigned from const member functions.
45  mutable std::shared_ptr<tensorflow::Session> tfSession;
// Flag declared mutable so const member functions can update it. It has no
// in-class initializer; presumably the constructor sets it - TODO confirm.
46  mutable bool initialized;
47 };
48 
std::string string
Definition: nybbler.cc:12
ModelConfig config
Definition: TFModel.h:44
Definition: VarDict.h:8
def predict(model, test_dir, N, trace, info)
Definition: predict.py:19
bool initialized
Definition: TFModel.h:46
static QMap< QCString, MemberDef * > varMap
Definition: vhdldocgen.cpp:713
std::shared_ptr< tensorflow::Session > tfSession
Definition: TFModel.h:45