cnn::TFRegNetHandler Class Reference

Wrapper for tf::RegCNNGraph which handles construction and prediction.

#include <TFRegNetHandler.h>

Public Member Functions

 TFRegNetHandler (const fhicl::ParameterSet &pset)
 Constructor which takes a pset with LibPath, TFProtoBuf, NInputs, OutputName and ReverseViews fields.
 
std::vector< float > Predict (const RegPixelMap &pm)
 Return prediction arrays for RegPixelMap.
 
std::vector< float > Predict (const RegPixelMap &pm, const std::vector< float > cm_list)
 
std::vector< float > PredictNuEEnergy (const RegPixelMap &pm)
 

Private Attributes

std::string fLibPath
 Library path (typically dune_pardata...)
 
std::string fTFProtoBuf
 Full path to the TensorFlow .pb file, built as fLibPath + "/" + the TFProtoBuf parameter.
 
unsigned int fInputs
 Number of tdcs for the network to classify.
 
std::vector< std::string > fOutputName
 
std::vector< bool > fReverseViews
 Do we need to reverse any views?
 
std::unique_ptr< tf::RegCNNGraph > fTFGraph
 Tensorflow graph.
 

Detailed Description

Wrapper for tf::RegCNNGraph which handles construction and prediction.

Definition at line 23 of file TFRegNetHandler.h.

Constructor & Destructor Documentation

cnn::TFRegNetHandler::TFRegNetHandler ( const fhicl::ParameterSet & pset )

Constructor which takes a pset with LibPath, TFProtoBuf, NInputs, OutputName and ReverseViews fields.

Definition at line 23 of file TFRegNetHandler.cxx.

23  :
24  fLibPath(cet::getenv(pset.get<std::string>("LibPath", ""))),
25  //fLibPath((pset.get<std::string>("LibPath", ""))),
26  fTFProtoBuf (fLibPath+"/"+pset.get<std::string>("TFProtoBuf")),
27  fInputs(pset.get<unsigned int>("NInputs")),
28  fOutputName(pset.get<std::vector<std::string>>("OutputName")),
29  fReverseViews(pset.get<std::vector<bool> >("ReverseViews"))
30  {
31 
32  // Construct the TF Graph object. The empty vector {} is used since the protobuf
33  // file gives the names of the output layer nodes
34  mf::LogInfo("TFRegNetHandler") << "Loading network: " << fTFProtoBuf << std::endl;
35  std::cout<<"Loading network: "<<fTFProtoBuf<<std::endl;
36  //fTFGraph = tf::RegCNNGraph::create(fTFProtoBuf.c_str(),fInputs,{});
38  if(!fTFGraph){
39  art::Exception(art::errors::Unknown) << "Tensorflow model not found or incorrect";
40  }
41 
42  }
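
The initializer list above resolves the model location in two steps: the LibPath parameter is passed to cet::getenv, and the TFProtoBuf file name is appended to the result with a "/" separator. As listed, the null check on fTFGraph builds an art::Exception without a throw keyword, so on its own it would not abort construction. The following is a minimal, self-contained sketch of the same two steps with an explicit throw. GetEnvOrThrow, ResolveModelPath, FakeGraph and RequireGraph are hypothetical names, and std::getenv and std::runtime_error stand in for cet::getenv and art::Exception; the behaviour for an unset variable is an assumption of the sketch, not a statement about cet::getenv.

```cpp
#include <cassert>
#include <cstdlib>
#include <memory>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for cet::getenv. Throwing when the variable is
// unset is a choice made for this sketch only.
std::string GetEnvOrThrow(const std::string& name) {
  const char* value = std::getenv(name.c_str());
  if (value == nullptr) {
    throw std::runtime_error("Environment variable not set: " + name);
  }
  return value;
}

// Mirrors the initializer list: libPathVar names an environment variable,
// and the model file name is appended to that variable's value.
std::string ResolveModelPath(const std::string& libPathVar,
                             const std::string& protoBufName) {
  return GetEnvOrThrow(libPathVar) + "/" + protoBufName;
}

// Hypothetical graph type; the class documented here holds a tf::RegCNNGraph.
struct FakeGraph {};

// Throws when graph creation returned a null pointer.
void RequireGraph(const std::unique_ptr<FakeGraph>& graph) {
  if (!graph) {
    throw std::runtime_error("Tensorflow model not found or incorrect");
  }
}
```

Keeping the path assembly in a small free function makes it possible to check the resolved path without loading a network.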

Member Function Documentation

std::vector< float > cnn::TFRegNetHandler::Predict ( const RegPixelMap & pm )

Return prediction arrays for RegPixelMap.

Definition at line 60 of file TFRegNetHandler.cxx.

61  {
62 
63  RegCNNImageUtils imageUtils;
64 
65  // Configure the image utility
66  imageUtils.SetViewReversal(fReverseViews);
67 
68  //std::cout << "Reverse views? [" << fReverseViews[0] << "," << fReverseViews[1] << "," << fReverseViews[2] << "]" << std::endl;
69 
70  ImageVectorF thisImage;
71  imageUtils.ConvertPixelMapToImageVectorF(pm,thisImage);
72  std::vector<ImageVectorF> vecForTF;
73 /*
74  // Does this image look sensible?
75  TCanvas *can = new TCanvas("can","",0,0,800,600);
76  unsigned int fImageWires = pm.fNWire;
77  unsigned int fImageTDCs = pm.fNTdc;
78  TH2D* hView0 = new TH2D("hView0","",fImageWires,0,fImageWires,fImageTDCs,0,fImageTDCs);
79  TH2D* hView1 = new TH2D("hView1","",fImageWires,0,fImageWires,fImageTDCs,0,fImageTDCs);
80  TH2D* hView2 = new TH2D("hView2","",fImageWires,0,fImageWires,fImageTDCs,0,fImageTDCs);
81  for(unsigned int w = 0; w < fImageWires; ++w){
82  for(unsigned int t = 0; t < fImageTDCs; ++t){
83  hView0->SetBinContent(w+1,t+1,thisImage[w][t][0]);
84  hView1->SetBinContent(w+1,t+1,thisImage[w][t][1]);
85  hView2->SetBinContent(w+1,t+1,thisImage[w][t][2]);
86  }
87  }
88  hView0->Draw("colz");
89  can->Print("view0a.png");
90  hView1->Draw("colz");
91  can->Print("view1a.png");
92  hView2->Draw("colz");
93  can->Print("view2a.png");
94 */
95  vecForTF.push_back(thisImage);
96 
97  auto cnnResults = fTFGraph->run(vecForTF, fInputs);
98 
99  //std::cout << "Number of CNN result vectors " << cnnResults.size() << " with " << cnnResults[0].size() << " categories" << std::endl;
100 
101  //std::cout << "summary: ";
102  //for(auto const v : cnnResults[0]){
103  // std::cout << v << ", ";
104  //}
105  //std::cout << std::endl;
106 
107  return cnnResults[0];
108  }
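
The listing above follows a batch-of-one pattern: the single converted image is pushed into a vector, the graph is run on that vector, and the first result is returned. A minimal sketch of that pattern with stand-in types is given here. MockGraph and PredictOne are hypothetical, the ViewVectorF alias is assumed to be a vector of float rows, and the mock network simply sums the pixels. The empty-result check is an addition of the sketch; the listed code indexes cnnResults[0] directly.

```cpp
#include <cassert>
#include <stdexcept>
#include <vector>

// Assumed shapes: an image is a vector of planes, a plane a vector of rows.
using ViewVectorF = std::vector<std::vector<float>>;
using ImageVectorF = std::vector<ViewVectorF>;

// Hypothetical stand-in for the graph: returns one output vector per
// input image, holding the sum of that image's pixels.
struct MockGraph {
  std::vector<std::vector<float>> run(const std::vector<ImageVectorF>& batch) const {
    std::vector<std::vector<float>> results;
    for (const ImageVectorF& image : batch) {
      float sum = 0.0f;
      for (const auto& plane : image)
        for (const auto& row : plane)
          for (float pixel : row) sum += pixel;
      results.push_back({sum});
    }
    return results;
  }
};

// Wraps one image in a batch, runs the graph, and returns the first result.
std::vector<float> PredictOne(const MockGraph& graph, const ImageVectorF& image) {
  std::vector<ImageVectorF> batch;
  batch.push_back(image);
  const auto results = graph.run(batch);
  if (results.empty()) {
    throw std::runtime_error("Network returned no results");
  }
  return results[0];
}
```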
std::vector< float > cnn::TFRegNetHandler::Predict ( const RegPixelMap & pm, const std::vector< float > cm_list )

Definition at line 43 of file TFRegNetHandler.cxx.

44  {
45 
46  RegCNNImageUtils imageUtils;
47 
48  // Configure the image utility
49  imageUtils.SetViewReversal(fReverseViews);
50  ImageVectorF thisImage;
51  imageUtils.ConvertPixelMapToImageVectorF(pm,thisImage);
52  std::vector<ImageVectorF> vecForTF;
53  vecForTF.push_back(thisImage);
54 
55  auto cnnResults = fTFGraph->run(vecForTF, cm_list, fInputs);
56  return cnnResults[0];
57  }
std::vector< float > cnn::TFRegNetHandler::PredictNuEEnergy ( const RegPixelMap & pm )

Definition at line 110 of file TFRegNetHandler.cxx.

110  {
111  std::vector<float> fullResults = this->Predict(pm);
112  std::vector<float> nue_energy;
113  nue_energy.push_back(fullResults[0]);
114  return nue_energy;
115  }
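
As listed, PredictNuEEnergy returns a one-element vector holding the first entry of the full prediction, and it reads fullResults[0] without checking that the vector is non-empty. A small sketch of the same extraction with a guard is shown here. FirstOutputOnly is a hypothetical helper, and returning an empty vector for empty input is just one possible policy; throwing would be equally reasonable.

```cpp
#include <cassert>
#include <vector>

// Returns a one-element vector holding the first network output,
// or an empty vector when there is no output to take.
std::vector<float> FirstOutputOnly(const std::vector<float>& fullResults) {
  if (fullResults.empty()) {
    return {};
  }
  return {fullResults.front()};
}
```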

Member Data Documentation

unsigned int cnn::TFRegNetHandler::fInputs
private

Number of tdcs for the network to classify.

Definition at line 40 of file TFRegNetHandler.h.

std::string cnn::TFRegNetHandler::fLibPath
private

Library path (typically dune_pardata...)

Definition at line 38 of file TFRegNetHandler.h.

std::vector<std::string> cnn::TFRegNetHandler::fOutputName
private

Definition at line 41 of file TFRegNetHandler.h.

std::vector<bool> cnn::TFRegNetHandler::fReverseViews
private

Do we need to reverse any views?

Definition at line 42 of file TFRegNetHandler.h.

std::unique_ptr<tf::RegCNNGraph> cnn::TFRegNetHandler::fTFGraph
private

Tensorflow graph.

Definition at line 43 of file TFRegNetHandler.h.

std::string cnn::TFRegNetHandler::fTFProtoBuf
private

Full path to the TensorFlow .pb file, built as fLibPath + "/" + the TFProtoBuf parameter.

Definition at line 39 of file TFRegNetHandler.h.


The documentation for this class was generated from the following files: