5 #include <torch/script.h> 6 #include <torch/torch.h> 14 #include "art_root_io/TFileDirectory.h" 15 #include "art_root_io/TFileService.h" 62 std::cout<<
"regcnn_torch job begins ...... "<<
std::endl;
68 std::cerr<<
"error loading the model\n";
79 std::unique_ptr< std::vector<RegCNNResult> >
80 resultCol(
new std::vector<RegCNNResult>);
83 std::vector< art::Ptr< cnn::RegPixelMap3D > > pixelmap3Dlist;
85 auto pixelmap3DListHandle = evt.
getHandle< std::vector< cnn::RegPixelMap3D > >(itag1);
86 if (pixelmap3DListHandle) {
90 if (pixelmap3Dlist.size() > 0) {
91 mf::LogDebug(
"RegCNNPyTorch::produce")<<
"3D pixel map was made for this event, loading it as the input";
101 t_pm = torch::from_blob(pm.
fPECropped.data(), {1,1,32,32,32});
103 t_pm = torch::from_blob(pm.
fPE.data(), {1,1,100,100,100});
106 std::vector<torch::jit::IValue> inputs_pm;
107 inputs_pm.push_back(t_pm);
108 at::Tensor torchOutput =
module.forward(inputs_pm).toTensor();
113 std::vector<float> networkOutput;
114 for (
unsigned int i= 0; i< 3; ++i) {
115 networkOutput.push_back(torchOutput[0][i].item<float>());
116 std::cout<<torchOutput[0][i].item<
float>()<<
std::endl;
119 resultCol->emplace_back(networkOutput);
Handle< PROD > getHandle(SelectorBase const &) const
RegCNNPyTorch(fhicl::ParameterSet const &pset)
EDProducer(fhicl::ParameterSet const &pset)
std::string fPixelMapInput
#define DEFINE_ART_MODULE(klass)
nvidia::inferenceserver::client::Error Error
std::string getenv(std::string const &name)
ProductID put(std::unique_ptr< PROD > &&edp, std::string const &instance={})
RegPixelMap3D for RegCNN modified from PixelMap.h.
Defines an enumeration for cellhit classification.
RegCNNResult for RegCNN modified from Result.h.
void produce(art::Event &evt)
MaybeLogger_< ELseverityLevel::ELsev_success, false > LogDebug
torch::jit::script::Module module
RegPixelMap3D, input to 3D CNN neural net.
auto const & get(AssnsNode< L, R, D > const &r)
void fill_ptr_vector(std::vector< Ptr< T >> &ptrs, H const &h)
std::vector< float > fPECropped
QTextStream & endl(QTextStream &s)
def load(filename, jpath="depos")