VLNEnergyModel.cxx
Go to the documentation of this file.
1 #include "VLNEnergyModel.h"
2 
4 #include "tensorflow/core/platform/env.h"
5 
6 namespace VLN
7 {
8 
9 const std::vector<InputConfigKeys> VLNEnergyModel::scalarInputKeys({
10  { "input_slice", "vars_slice" }
11 });
12 
13 const std::vector<InputConfigKeys> VLNEnergyModel::vectorInputKeys({
14  { "input_png3d", "vars_png3d" }
15 });
16 
17 const std::vector<std::string> VLNEnergyModel::outputKeys({
18  "target_primary", "target_total"
19 });
20 
// NOTE(review): the constructor's header line is absent from this listing;
// only its member-initializer list and empty body are shown here.
// Initializes `model` from `savedir` plus the static scalarInputKeys,
// vectorInputKeys and outputKeys lists; the body itself does nothing.
22  : model(savedir, scalarInputKeys, vectorInputKeys, outputKeys)
23 { }
24 
// NOTE(review): the function's header line is absent from this listing; the
// body below reads a parameter named `vars` and returns a VLNEnergy.
26 {
 // Run the wrapped model on `vars`; the result is a vector of tensors.
27  std::vector<tensorflow::Tensor> outputs = model.predict(vars);
28 
 // Read element (0, 0) of the first two outputs, each viewed as a rank-2
 // float tensor.
 // NOTE(review): presumably outputs[0] / outputs[1] follow the order of
 // outputKeys ("target_primary", "target_total") — confirm in model.predict.
 // `outputs` is indexed without a size check.
29  const float primaryE = outputs[0].tensor<float,2>()(0, 0);
30  const float totalE = outputs[1].tensor<float,2>()(0, 0);
31 
 // Package the two scalars, primary first and total second.
32  return VLNEnergy{ primaryE, totalE };
33 }
34 
35 }
static const std::vector< std::string > outputKeys
std::string string
Definition: nybbler.cc:12
VLNEnergy predict(const VarDict &varDict) const
Definition: model.py:1
Definition: VarDict.h:8
Definition: utils.cxx:6
static const std::vector< InputConfigKeys > scalarInputKeys
VLNEnergyModel(const std::string &savedir)
static const std::vector< InputConfigKeys > vectorInputKeys