Public Member Functions | Private Attributes | List of all members
cnn::RegCNNPyTorch Class Reference
Inheritance diagram for cnn::RegCNNPyTorch:
art::EDProducer art::detail::Producer art::detail::LegacyModule art::Modifier art::ModuleBase art::ProductRegistryHelper

Public Member Functions

 RegCNNPyTorch (fhicl::ParameterSet const &pset)
 
 ~RegCNNPyTorch ()
 
void produce (art::Event &evt)
 
void beginJob ()
 
void endJob ()
 
- Public Member Functions inherited from art::EDProducer
 EDProducer (fhicl::ParameterSet const &pset)
 
template<typename Config >
 EDProducer (Table< Config > const &config)
 
std::string workerType () const
 
- Public Member Functions inherited from art::detail::Producer
virtual ~Producer () noexcept
 
 Producer (fhicl::ParameterSet const &)
 
 Producer (Producer const &)=delete
 
 Producer (Producer &&)=delete
 
Producer & operator= (Producer const &)=delete
 
Producer & operator= (Producer &&)=delete
 
void doBeginJob (SharedResources const &resources)
 
void doEndJob ()
 
void doRespondToOpenInputFile (FileBlock const &fb)
 
void doRespondToCloseInputFile (FileBlock const &fb)
 
void doRespondToOpenOutputFiles (FileBlock const &fb)
 
void doRespondToCloseOutputFiles (FileBlock const &fb)
 
bool doBeginRun (RunPrincipal &rp, ModuleContext const &mc)
 
bool doEndRun (RunPrincipal &rp, ModuleContext const &mc)
 
bool doBeginSubRun (SubRunPrincipal &srp, ModuleContext const &mc)
 
bool doEndSubRun (SubRunPrincipal &srp, ModuleContext const &mc)
 
bool doEvent (EventPrincipal &ep, ModuleContext const &mc, std::atomic< std::size_t > &counts_run, std::atomic< std::size_t > &counts_passed, std::atomic< std::size_t > &counts_failed)
 
- Public Member Functions inherited from art::Modifier
 ~Modifier () noexcept
 
 Modifier ()
 
 Modifier (Modifier const &)=delete
 
 Modifier (Modifier &&)=delete
 
Modifier & operator= (Modifier const &)=delete
 
Modifier & operator= (Modifier &&)=delete
 
- Public Member Functions inherited from art::ModuleBase
virtual ~ModuleBase () noexcept
 
 ModuleBase ()
 
ModuleDescription const & moduleDescription () const
 
void setModuleDescription (ModuleDescription const &)
 
std::array< std::vector< ProductInfo >, NumBranchTypes > const & getConsumables () const
 
void sortConsumables (std::string const &current_process_name)
 
template<typename T , BranchType BT>
ViewToken< T > consumesView (InputTag const &tag)
 
template<typename T , BranchType BT>
ViewToken< T > mayConsumeView (InputTag const &tag)
 

Private Attributes

std::string fLibPath
 
std::string fNetwork
 
std::string fPixelMapInput
 
std::string fResultLabel
 
torch::jit::script::Module module
 

Additional Inherited Members

- Public Types inherited from art::EDProducer
using ModuleType = EDProducer
 
using WorkerType = WorkerT< EDProducer >
 
- Public Types inherited from art::detail::Producer
template<typename UserConfig , typename KeysToIgnore = void>
using Table = Modifier::Table< UserConfig, KeysToIgnore >
 
- Public Types inherited from art::Modifier
template<typename UserConfig , typename UserKeysToIgnore = void>
using Table = ProducerTable< UserConfig, detail::ModuleConfig, UserKeysToIgnore >
 
- Static Public Member Functions inherited from art::EDProducer
static void commitEvent (EventPrincipal &ep, Event &e)
 
- Protected Member Functions inherited from art::ModuleBase
ConsumesCollector & consumesCollector ()
 
template<typename T , BranchType = InEvent>
ProductToken< T > consumes (InputTag const &)
 
template<typename Element , BranchType = InEvent>
ViewToken< Element > consumesView (InputTag const &)
 
template<typename T , BranchType = InEvent>
void consumesMany ()
 
template<typename T , BranchType = InEvent>
ProductToken< T > mayConsume (InputTag const &)
 
template<typename Element , BranchType = InEvent>
ViewToken< Element > mayConsumeView (InputTag const &)
 
template<typename T , BranchType = InEvent>
void mayConsumeMany ()
 

Detailed Description

Definition at line 27 of file RegCNNPyTorch_module.cc.

Constructor & Destructor Documentation

cnn::RegCNNPyTorch::RegCNNPyTorch ( fhicl::ParameterSet const &  pset)
explicit

Definition at line 47 of file RegCNNPyTorch_module.cc.

47  :
48  EDProducer(pset),
49  fLibPath (cet::getenv(pset.get<std::string> ("LibPath", ""))),
50  fNetwork (fLibPath + "/" + pset.get<std::string> ("Network")),
51  fPixelMapInput (pset.get<std::string> ("PixelMapInput")),
52  fResultLabel (pset.get<std::string> ("ResultLabel"))
53  {
54  produces<std::vector<cnn::RegCNNResult> >(fResultLabel);
55  }
std::string string
Definition: nybbler.cc:12
EDProducer(fhicl::ParameterSet const &pset)
Definition: EDProducer.h:20
std::string getenv(std::string const &name)
Definition: getenv.cc:15
cnn::RegCNNPyTorch::~RegCNNPyTorch ( )

Definition at line 57 of file RegCNNPyTorch_module.cc.

57  {
58 
59  }

Member Function Documentation

void cnn::RegCNNPyTorch::beginJob ( )
virtual

Reimplemented from art::EDProducer.

Definition at line 61 of file RegCNNPyTorch_module.cc.

61  {
62  std::cout<<"regcnn_torch job begins ...... "<<std::endl;
63  try {
64  // Deserialize the ScriptModule from a file using torch::jit::load().
65  module = torch::jit::load(fNetwork);
66  }
67  catch (const c10::Error& e) {
68  std::cerr<<"error loading the model\n";
69  return;
70  }
71  mf::LogDebug("RegCNNPyTorch::beginJob")<<"loaded model "<<fNetwork<<" ... ok\n";
72  }
const double e
nvidia::inferenceserver::client::Error Error
Definition: triton_utils.h:15
MaybeLogger_< ELseverityLevel::ELsev_success, false > LogDebug
torch::jit::script::Module module
QTextStream & endl(QTextStream &s)
def load(filename, jpath="depos")
Definition: depos.py:34
void cnn::RegCNNPyTorch::endJob ( )
virtual

Reimplemented from art::EDProducer.

Definition at line 74 of file RegCNNPyTorch_module.cc.

74  {
75  }
void cnn::RegCNNPyTorch::produce ( art::Event & evt)
virtual

Define containers for the things we're going to produce

Load 3D pixel map for direction reco.

Implements art::EDProducer.

Definition at line 77 of file RegCNNPyTorch_module.cc.

77  {
78  /// Define containers for the things we're going to produce
79  std::unique_ptr< std::vector<RegCNNResult> >
80  resultCol(new std::vector<RegCNNResult>);
81 
82  /// Load 3D pixel map for direction reco.
83  std::vector< art::Ptr< cnn::RegPixelMap3D > > pixelmap3Dlist;
85  auto pixelmap3DListHandle = evt.getHandle< std::vector< cnn::RegPixelMap3D > >(itag1);
86  if (pixelmap3DListHandle) {
87  art::fill_ptr_vector(pixelmap3Dlist, pixelmap3DListHandle);
88  }
89 
90  if (pixelmap3Dlist.size() > 0) {
91  mf::LogDebug("RegCNNPyTorch::produce")<<"3D pixel map was made for this event, loading it as the input";
92 
93  RegPixelMap3D pm = *pixelmap3Dlist[0];
94 
95  // Convert RegPixelMap3D to at::Tensor, which is the actual input of the network
96  // Currently we have two configurations for the 3D pixel map
97  // 100*100*100: a pixel map centered at the vertex, need longer evaluation time
98  // 32*32*32: cropped pixel map around the vertex of the interaction from 100*100*100 pm, faster
99  at::Tensor t_pm;
100  if (pm.IsCroppedPM()) {
101  t_pm = torch::from_blob(pm.fPECropped.data(), {1,1,32,32,32});
102  } else {
103  t_pm = torch::from_blob(pm.fPE.data(), {1,1,100,100,100});
104  }
105 
106  std::vector<torch::jit::IValue> inputs_pm;
107  inputs_pm.push_back(t_pm);
108  at::Tensor torchOutput = module.forward(inputs_pm).toTensor();
109 
110  // Output of the network
111  // Currently only direction reconstruction utilizes 3D CNN, which has 3 output represent 3 components
112  // Absolute value of 3 output are meaningless, their combination is the direction of the prong
113  std::vector<float> networkOutput;
114  for (unsigned int i= 0; i< 3; ++i) {
115  networkOutput.push_back(torchOutput[0][i].item<float>());
116  std::cout<<torchOutput[0][i].item<float>()<<std::endl;
117  }
118 
119  resultCol->emplace_back(networkOutput);
120 
121  //std::cout<<output.slice(/*dim=*/1, /*start=*/0, /*end=*/10) << '\n';
122  //std::cout<<output[0]<<std::endl;
123  }
124 
125  evt.put(std::move(resultCol), fResultLabel);
126  }
Handle< PROD > getHandle(SelectorBase const &) const
Definition: DataViewImpl.h:382
def move(depos, offset)
Definition: depos.py:107
ProductID put(std::unique_ptr< PROD > &&edp, std::string const &instance={})
Definition: DataViewImpl.h:686
MaybeLogger_< ELseverityLevel::ELsev_success, false > LogDebug
torch::jit::script::Module module
void fill_ptr_vector(std::vector< Ptr< T >> &ptrs, H const &h)
Definition: Ptr.h:297
QTextStream & endl(QTextStream &s)

Member Data Documentation

std::string cnn::RegCNNPyTorch::fLibPath
private

Definition at line 39 of file RegCNNPyTorch_module.cc.

std::string cnn::RegCNNPyTorch::fNetwork
private

Definition at line 40 of file RegCNNPyTorch_module.cc.

std::string cnn::RegCNNPyTorch::fPixelMapInput
private

Definition at line 41 of file RegCNNPyTorch_module.cc.

std::string cnn::RegCNNPyTorch::fResultLabel
private

Definition at line 42 of file RegCNNPyTorch_module.cc.

torch::jit::script::Module cnn::RegCNNPyTorch::module
private

Definition at line 44 of file RegCNNPyTorch_module.cc.


The documentation for this class was generated from the following file: