4 #include "tensorflow/core/platform/env.h" 10 {
"input_slice",
"vars_slice" }
14 {
"input_png3d",
"vars_png3d" }
18 "target_primary",
"target_total" 22 :
model(savedir, scalarInputKeys, vectorInputKeys, outputKeys)
27 std::vector<tensorflow::Tensor>
outputs =
model.predict(vars);
29 const float primaryE = outputs[0].tensor<
float,2>()(0, 0);
30 const float totalE = outputs[1].tensor<
float,2>()(0, 0);
static const std::vector< std::string > outputKeys
VLNEnergy predict(const VarDict &varDict) const
static const std::vector< InputConfigKeys > scalarInputKeys
VLNEnergyModel(const std::string &savedir)
static const std::vector< InputConfigKeys > vectorInputKeys