RegCNNPyTorch_module.cc
Go to the documentation of this file.
1 // C/C++ includes
2 #include <iostream>
3 #include <sstream>
4 
5 #include <torch/script.h>
6 #include <torch/torch.h>
7 
8 #include "cetlib/getenv.h"
9 
10 // Framework includes
14 #include "art_root_io/TFileDirectory.h"
15 #include "art_root_io/TFileService.h"
17 #include "fhiclcpp/ParameterSet.h"
21 
24 
// NOTE(review): this listing is a Doxygen source-browser extract. Several
// original lines (member declarations, function header lines, the model-load
// call and the definition of itag1) are not shown here.
25 namespace cnn {
26 
/// RegCNNPyTorch: art producer that evaluates a TorchScript network on the
/// event's RegPixelMap3D and stores the network output as a
/// std::vector<cnn::RegCNNResult> under the instance label fResultLabel.
27  class RegCNNPyTorch : public art::EDProducer {
28 
29  public:
/// Reads the FHiCL parameters LibPath, Network, PixelMapInput and ResultLabel
/// and declares the output product.
30  explicit RegCNNPyTorch(fhicl::ParameterSet const& pset);
32 
// NOTE(review): these presumably override art::EDProducer virtuals; consider
// marking them `override` so signature drift is caught at compile time.
/// Runs the network on the first 3D pixel map of the event (if any) and always
/// puts the (possibly empty) result collection into the event.
33  void produce(art::Event& evt);
/// Prints a banner and deserializes the TorchScript model (the load call itself
/// is not shown here; the success log reports fNetwork as the model path).
34  void beginJob();
35  void endJob();
36 
37  private:
// Private members are declared on lines not shown in this listing; from the
// constructor and produce() they presumably include fLibPath, fNetwork,
// fPixelMapInput, fResultLabel and the TorchScript `module`.
38 
43 
45  }; // class RegCNNPyTorch
46 
// Constructor initializer list (the signature line is not shown here).
// - LibPath: passed to cet::getenv, i.e. treated as the *name* of an
//   environment variable whose value is the directory holding the network.
//   NOTE(review): when LibPath is omitted this becomes cet::getenv("");
//   confirm that is intended (cetlib may throw for an undefined variable).
// - Network: file name appended to that directory with a "/" separator.
// - PixelMapInput: presumably the label used to build itag1 in produce();
//   TODO confirm against the hidden line.
// - ResultLabel: instance label used for both produces<>() and evt.put().
48  EDProducer(pset),
49  fLibPath (cet::getenv(pset.get<std::string> ("LibPath", ""))),
50  fNetwork (fLibPath + "/" + pset.get<std::string> ("Network")),
51  fPixelMapInput (pset.get<std::string> ("PixelMapInput")),
52  fResultLabel (pset.get<std::string> ("ResultLabel"))
53  {
54  produces<std::vector<cnn::RegCNNResult> >(fResultLabel);
55  }
56 
// Closing brace of a definition whose header is not shown (presumably the
// destructor, whose visible body is empty).
58 
59  }
60 
// beginJob() body (header line not shown): loads the TorchScript model.
62  std::cout<<"regcnn_torch job begins ...... "<<std::endl;
63  try {
64  // Deserialize the ScriptModule from a file using torch::jit::load().
66  }
67  catch (const c10::Error& e) {
// NOTE(review): the failure is only reported, without e.what(). The job then
// continues with an unloaded `module`, and produce() calls module.forward()
// (whenever a pixel map is present) without checking that loading succeeded.
// Consider rethrowing, e.g. as an art::Exception, so that a bad model path
// fails the job immediately.
68  std::cerr<<"error loading the model\n";
69  return;
70  }
71  mf::LogDebug("RegCNNPyTorch::beginJob")<<"loaded model "<<fNetwork<<" ... ok\n";
72  }
73 
// endJob() (header line not shown): appears to do nothing.
75  }
76 
// produce() body (header line not shown).
78  /// Define containers for the things we're going to produce
79  std::unique_ptr< std::vector<RegCNNResult> >
80  resultCol(new std::vector<RegCNNResult>);
81 
82  /// Load 3D pixel map for direction reco.
83  std::vector< art::Ptr< cnn::RegPixelMap3D > > pixelmap3Dlist;
// itag1 is defined on a line not shown here (presumably from fPixelMapInput).
// A missing product leaves pixelmap3Dlist empty rather than throwing.
85  auto pixelmap3DListHandle = evt.getHandle< std::vector< cnn::RegPixelMap3D > >(itag1);
86  if (pixelmap3DListHandle) {
87  art::fill_ptr_vector(pixelmap3Dlist, pixelmap3DListHandle);
88  }
89 
90  if (pixelmap3Dlist.size() > 0) {
91  mf::LogDebug("RegCNNPyTorch::produce")<<"3D pixel map was made for this event, loading it as the input";
92 
// Only the first pixel map in the collection is evaluated.
// NOTE(review): this copies the whole map, including its std::vector<float>
// pixel buffers. The copy may exist because torch::from_blob takes a
// non-const pointer.
93  RegPixelMap3D pm = *pixelmap3Dlist[0];
94 
95  // Convert RegPixelMap3D to at::Tensor, which is the actual input of the network
96  // Currently we have two configurations for the 3D pixel map
97  // 100*100*100: a pixel map centered at the vertex; needs longer evaluation time
98  // 32*32*32: cropped pixel map around the vertex of the interaction from 100*100*100 pm, faster
99  at::Tensor t_pm;
// torch::from_blob wraps the existing buffer without copying or taking
// ownership (default dtype float32, matching the std::vector<float> buffers).
// `pm` must therefore outlive forward(); it does, since it is a local of this
// scope.
// NOTE(review): the buffer sizes are not checked against 32^3 / 100^3; a
// pixel map of any other size would be read out of bounds.
100  if (pm.IsCroppedPM()) {
101  t_pm = torch::from_blob(pm.fPECropped.data(), {1,1,32,32,32});
102  } else {
103  t_pm = torch::from_blob(pm.fPE.data(), {1,1,100,100,100});
104  }
105 
106  std::vector<torch::jit::IValue> inputs_pm;
107  inputs_pm.push_back(t_pm);
108  at::Tensor torchOutput = module.forward(inputs_pm).toTensor();
109 
110  // Output of the network
111  // Currently only direction reconstruction uses the 3D CNN, which has 3 outputs representing 3 components
112  // The absolute values of the 3 outputs are not meaningful; only their combination gives the direction of the prong
// NOTE(review): this loop assumes the output has at least shape (1, 3);
// TODO confirm against the model. The std::cout line echoes each value to
// stdout on every event; consider mf::LogDebug instead.
113  std::vector<float> networkOutput;
114  for (unsigned int i= 0; i< 3; ++i) {
115  networkOutput.push_back(torchOutput[0][i].item<float>());
116  std::cout<<torchOutput[0][i].item<float>()<<std::endl;
117  }
118 
119  resultCol->emplace_back(networkOutput);
120 
// Stale debug lines: `output` refers to what is now named torchOutput.
121  //std::cout<<output.slice(/*dim=*/1, /*start=*/0, /*end=*/10) << '\n';
122  //std::cout<<output[0]<<std::endl;
123  }
124 
// Put the collection even when it is empty, so the product exists for every event.
125  evt.put(std::move(resultCol), fResultLabel);
126  }
127 
129 } // end namespace cnn
Handle< PROD > getHandle(SelectorBase const &) const
Definition: DataViewImpl.h:382
std::string string
Definition: nybbler.cc:12
RegCNNPyTorch(fhicl::ParameterSet const &pset)
EDProducer(fhicl::ParameterSet const &pset)
Definition: EDProducer.h:20
STL namespace.
std::vector< float > fPE
Definition: RegPixelMap3D.h:48
const double e
#define DEFINE_ART_MODULE(klass)
Definition: ModuleMacros.h:67
#define Module
nvidia::inferenceserver::client::Error Error
Definition: triton_utils.h:15
std::string getenv(std::string const &name)
Definition: getenv.cc:15
def move(depos, offset)
Definition: depos.py:107
ProductID put(std::unique_ptr< PROD > &&edp, std::string const &instance={})
Definition: DataViewImpl.h:686
RegPixelMap3D for RegCNN modified from PixelMap.h.
Defines an enumeration for cellhit classification.
RegCNNResult for RegCNN modified from Result.h.
void produce(art::Event &evt)
MaybeLogger_< ELseverityLevel::ELsev_success, false > LogDebug
torch::jit::script::Module module
RegPixelMap3D, input to 3D CNN neural net.
Definition: RegPixelMap3D.h:22
TCEvent evt
Definition: DataStructs.cxx:7
auto const & get(AssnsNode< L, R, D > const &r)
Definition: AssnsNode.h:115
void fill_ptr_vector(std::vector< Ptr< T >> &ptrs, H const &h)
Definition: Ptr.h:297
std::vector< float > fPECropped
Definition: RegPixelMap3D.h:49
bool IsCroppedPM() const
Definition: RegPixelMap3D.h:29
QTextStream & endl(QTextStream &s)
def load(filename, jpath="depos")
Definition: depos.py:34