LArDLHelper.h
Go to the documentation of this file.
1 /**
2  * @file larpandoradlcontent/LArHelpers/LArDLHelper.h
3  *
4  * @brief Header file for the lar deep learning helper helper class.
5  *
6  * $Log: $
7  */
8 #ifndef LAR_DL_HELPER_H
9 #define LAR_DL_HELPER_H 1
10 
11 #include <torch/script.h>
12 #include <torch/torch.h>
13 
14 #include "Pandora/StatusCodes.h"
15 
16 namespace lar_dl_content
17 {
18 
19 /**
20  * @brief LArDLHelper class
21  */
23 {
24 public:
26  typedef torch::Tensor TorchInput;
27  typedef std::vector<torch::jit::IValue> TorchInputVector;
28  typedef at::Tensor TorchOutput;
29 
30  /**
31  * @brief Loads a deep learning model
32  *
33  * @param filename the filename of the model to load
34  * @param model the TorchModel in which to store the loaded model
35  *
36  * @return STATUS_CODE_SUCCESS upon successful loading of the model. STATUS_CODE_FAILURE otherwise.
37  */
38  static pandora::StatusCode LoadModel(const std::string &filename, TorchModel &model);
39 
40  /**
41  * @brief Create a torch input tensor
42  *
43  * @param dimensions the size of each dimension of the tensor: pass as {a, b, c, d} for example
44  * @param tensor the tensor to be initialised
45  */
46  static void InitialiseInput(const at::IntArrayRef dimensions, TorchInput &tensor);
47 
48  /**
49  * @brief Run a deep learning model
50  *
51  * @param model the model to run
52  * @param input the input to run over
53  * @param output the tensor to store the output in
54  */
55  static void Forward(TorchModel &model, const TorchInputVector &input, TorchOutput &output);
56 };
57 
58 } // namespace lar_dl_content
59 
60 #endif // #ifndef LAR_DL_HELPER_H
std::string string
Definition: nybbler.cc:12
Definition: model.py:1
torch::jit::script::Module TorchModel
Definition: LArDLHelper.h:25
string filename
Definition: train.py:213
static int input(void)
Definition: code.cpp:15695
#define Module
static void Forward(TorchModel &model, const TorchInputVector &input, TorchOutput &output)
Run a deep learning model.
Definition: LArDLHelper.cc:41
LArDLHelper class.
Definition: LArDLHelper.h:22
static pandora::StatusCode LoadModel(const std::string &filename, TorchModel &model)
Loads a deep learning model.
Definition: LArDLHelper.cc:16
static void InitialiseInput(const at::IntArrayRef dimensions, TorchInput &tensor)
Create a torch input tensor.
Definition: LArDLHelper.cc:34
std::vector< torch::jit::IValue > TorchInputVector
Definition: LArDLHelper.h:27