TFRegNetHandler.cxx
Go to the documentation of this file.
1 ////////////////////////////////////////////////////////////////////////
2 /// \file TFRegNetHandler.cxx
3 /// \brief TFRegNetHandler for RegCNN modified from TFNetHandler.cxx
4 /// \author Ilsoo Seong - iseong@uci.edu
5 ////////////////////////////////////////////////////////////////////////
6 
#include <iostream>
#include <string>
#include <utility>
#include "cetlib/getenv.h"
10 
13 
16 
17 #include "TH2D.h"
18 #include "TCanvas.h"
19 
20 namespace cnn
21 {
22 
24  fLibPath(cet::getenv(pset.get<std::string>("LibPath", ""))),
25  //fLibPath((pset.get<std::string>("LibPath", ""))),
26  fTFProtoBuf (fLibPath+"/"+pset.get<std::string>("TFProtoBuf")),
27  fInputs(pset.get<unsigned int>("NInputs")),
28  fOutputName(pset.get<std::vector<std::string>>("OutputName")),
29  fReverseViews(pset.get<std::vector<bool> >("ReverseViews"))
30  {
31 
32  // Construct the TF Graph object. The empty vector {} is used since the protobuf
33  // file gives the names of the output layer nodes
34  mf::LogInfo("TFRegNetHandler") << "Loading network: " << fTFProtoBuf << std::endl;
35  std::cout<<"Loading network: "<<fTFProtoBuf<<std::endl;
36  //fTFGraph = tf::RegCNNGraph::create(fTFProtoBuf.c_str(),fInputs,{});
38  if(!fTFGraph){
39  art::Exception(art::errors::Unknown) << "Tensorflow model not found or incorrect";
40  }
41 
42  }
43  std::vector<float> TFRegNetHandler::Predict(const RegPixelMap& pm, const std::vector<float> cm_list)
44  {
45 
46  RegCNNImageUtils imageUtils;
47 
48  // Configure the image utility
49  imageUtils.SetViewReversal(fReverseViews);
50  ImageVectorF thisImage;
51  imageUtils.ConvertPixelMapToImageVectorF(pm,thisImage);
52  std::vector<ImageVectorF> vecForTF;
53  vecForTF.push_back(thisImage);
54 
55  auto cnnResults = fTFGraph->run(vecForTF, cm_list, fInputs);
56  return cnnResults[0];
57  }
58 
59 
60  std::vector<float> TFRegNetHandler::Predict(const RegPixelMap& pm)
61  {
62 
63  RegCNNImageUtils imageUtils;
64 
65  // Configure the image utility
66  imageUtils.SetViewReversal(fReverseViews);
67 
68  //std::cout << "Reverse views? [" << fReverseViews[0] << "," << fReverseViews[1] << "," << fReverseViews[2] << "]" << std::endl;
69 
70  ImageVectorF thisImage;
71  imageUtils.ConvertPixelMapToImageVectorF(pm,thisImage);
72  std::vector<ImageVectorF> vecForTF;
73 /*
74  // Does this image look sensible?
75  TCanvas *can = new TCanvas("can","",0,0,800,600);
76  unsigned int fImageWires = pm.fNWire;
77  unsigned int fImageTDCs = pm.fNTdc;
78  TH2D* hView0 = new TH2D("hView0","",fImageWires,0,fImageWires,fImageTDCs,0,fImageTDCs);
79  TH2D* hView1 = new TH2D("hView1","",fImageWires,0,fImageWires,fImageTDCs,0,fImageTDCs);
80  TH2D* hView2 = new TH2D("hView2","",fImageWires,0,fImageWires,fImageTDCs,0,fImageTDCs);
81  for(unsigned int w = 0; w < fImageWires; ++w){
82  for(unsigned int t = 0; t < fImageTDCs; ++t){
83  hView0->SetBinContent(w+1,t+1,thisImage[w][t][0]);
84  hView1->SetBinContent(w+1,t+1,thisImage[w][t][1]);
85  hView2->SetBinContent(w+1,t+1,thisImage[w][t][2]);
86  }
87  }
88  hView0->Draw("colz");
89  can->Print("view0a.png");
90  hView1->Draw("colz");
91  can->Print("view1a.png");
92  hView2->Draw("colz");
93  can->Print("view2a.png");
94 */
95  vecForTF.push_back(thisImage);
96 
97  auto cnnResults = fTFGraph->run(vecForTF, fInputs);
98 
99  //std::cout << "Number of CNN result vectors " << cnnResults.size() << " with " << cnnResults[0].size() << " categories" << std::endl;
100 
101  //std::cout << "summary: ";
102  //for(auto const v : cnnResults[0]){
103  // std::cout << v << ", ";
104  //}
105  //std::cout << std::endl;
106 
107  return cnnResults[0];
108  }
109 
111  std::vector<float> fullResults = this->Predict(pm);
112  std::vector<float> nue_energy;
113  nue_energy.push_back(fullResults[0]);
114  return nue_energy;
115  }
116 
117 }
118 
std::string string
Definition: nybbler.cc:12
MaybeLogger_< ELseverityLevel::ELsev_info, false > LogInfo
RegPixelMap, basic input to CNN neural net.
Definition: RegPixelMap.h:22
std::vector< float > PredictNuEEnergy(const RegPixelMap &pm)
STL namespace.
unsigned int fInputs
Number of inputs to the network, read from the NInputs parameter.
std::vector< bool > fReverseViews
Do we need to reverse any views?
void SetViewReversal(bool reverseX, bool reverseY, bool reverseZ)
Function to set any views that need reversing.
std::vector< ViewVectorF > ImageVectorF
std::vector< float > Predict(const RegPixelMap &pm)
Return prediction arrays for RegPixelMap.
std::string getenv(std::string const &name)
Definition: getenv.cc:15
Utilities for producing images for the RegCNN.
void ConvertPixelMapToImageVectorF(const RegPixelMap &pm, ImageVectorF &imageVec)
Convert a pixel map into an image vector (float version)
cet::coded_exception< errors::ErrorCodes, ExceptionDetail::translate > Exception
Definition: Exception.h:66
Defines an enumeration for cellhit classification.
TFRegNetHandler(const fhicl::ParameterSet &pset)
Constructor which takes a pset with LibPath, TFProtoBuf, NInputs, OutputName and ReverseViews fields.
TFRegNetHandler for RegCNN modified from TFNetHandler.h.
static std::unique_ptr< RegCNNGraph > create(const char *graph_file_name, const unsigned int &ninputs, const std::vector< std::string > &outputs={})
Class containing some utility functions for all things RegCNN.
auto const & get(AssnsNode< L, R, D > const &r)
Definition: AssnsNode.h:115
std::vector< std::string > fOutputName
int bool
Definition: qglobal.h:345
std::unique_ptr< tf::RegCNNGraph > fTFGraph
Tensorflow graph.
QTextStream & endl(QTextStream &s)
std::string fTFProtoBuf
location of the tf .pb file in the above path