All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
Public Member Functions | Private Attributes | List of all members
PointIdAlgTools::PointIdAlgTrtis Class Reference
Inheritance diagram for PointIdAlgTools::PointIdAlgTrtis:
PointIdAlgTools::IPointIdAlg img::DataProviderAlg

Public Member Functions

 PointIdAlgTrtis (fhicl::Table< Config > const &table)
 
std::vector< float > Run (std::vector< std::vector< float >> const &inp2d) const override
 
std::vector< std::vector< float > > Run (std::vector< std::vector< std::vector< float >>> const &inps, int samples=-1) const override
 
- Public Member Functions inherited from PointIdAlgTools::IPointIdAlg
virtual ~IPointIdAlg () noexcept=default
 
float predictIdValue (unsigned int wire, float drift, size_t outIdx=0)
 
std::vector< float > predictIdVector (unsigned int wire, float drift)
 
std::vector< std::vector< float > > predictIdVectors (const std::vector< std::pair< unsigned int, float >> &points)
 
std::vector< std::string > const & outputLabels (void) const
 
bool isInsideFiducialRegion (unsigned int wire, float drift) const
 
- Public Member Functions inherited from img::DataProviderAlg
 DataProviderAlg (const fhicl::ParameterSet &pset)
 
 DataProviderAlg (const Config &config)
 
virtual ~DataProviderAlg ()
 
bool setWireDriftData (const detinfo::DetectorClocksData &clock_data, const detinfo::DetectorPropertiesData &det_prop, const std::vector< recob::Wire > &wires, unsigned int plane, unsigned int tpc, unsigned int cryo)
 
std::vector< float > const & wireData (size_t widx) const
 
std::vector< std::vector< float > > getPatch (size_t wire, float drift, size_t patchSizeW, size_t patchSizeD) const
 
float getPixelOrZero (int wire, int drift) const
 
double getAdcSum () const
 
size_t getAdcArea () const
 
float poolMax (int wire, int drift, size_t r=0) const
 Pool max value in a patch around the wire/drift pixel. More...
 
float poolSum (int wire, int drift, size_t r=0) const
 Pool sum of pixels in a patch around the wire/drift pixel. More...
 
unsigned int Cryo () const
 
unsigned int TPC () const
 
unsigned int Plane () const
 
unsigned int NWires () const
 
unsigned int NScaledDrifts () const
 
unsigned int NCachedDrifts () const
 
unsigned int DriftWindow () const
 
float ZeroLevel () const
 Level of zero ADC after scaling. More...
 
double LifetimeCorrection (detinfo::DetectorClocksData const &clock_data, detinfo::DetectorPropertiesData const &det_prop, double tick) const
 

Private Attributes

std::string fTrtisModelName
 
std::string fTrtisURL
 
bool fTrtisVerbose
 
int64_t fTrtisModelVersion
 
std::unique_ptr< nic::InferContext > ctx
 
std::shared_ptr< nic::InferContext::Input > model_input
 

Additional Inherited Members

- Public Types inherited from img::DataProviderAlg
enum  EDownscaleMode { kMax = 1, kMaxMean = 2, kMean = 3 }
 
- Protected Member Functions inherited from PointIdAlgTools::IPointIdAlg
bool bufferPatch (size_t wire, float drift, std::vector< std::vector< float >> &patch)
 
bool bufferPatch (size_t wire, float drift)
 
void resizePatch (void)
 
- Protected Member Functions inherited from img::DataProviderAlg
void downscaleMax (std::vector< float > &dst, std::vector< float > const &adc, size_t tick0) const
 
void downscaleMaxMean (std::vector< float > &dst, std::vector< float > const &adc, size_t tick0) const
 
void downscaleMean (std::vector< float > &dst, std::vector< float > const &adc, size_t tick0) const
 
void downscale (std::vector< float > &dst, std::vector< float > const &adc, size_t tick0) const
 
size_t getDriftIndex (float drift) const
 
bool setWireData (std::vector< float > const &adc, size_t wireIdx)
 
bool patchFromDownsampledView (size_t wire, float drift, size_t size_w, size_t size_d, std::vector< std::vector< float >> &patch) const
 
bool patchFromOriginalView (size_t wire, float drift, size_t size_w, size_t size_d, std::vector< std::vector< float >> &patch) const
 
virtual void resizeView (detinfo::DetectorClocksData const &clock_data, detinfo::DetectorPropertiesData const &det_prop, size_t wires, size_t drifts)
 
- Protected Attributes inherited from PointIdAlgTools::IPointIdAlg
std::vector< std::string > fNNetOutputs
 
size_t fPatchSizeW
 
size_t fPatchSizeD
 
std::vector< std::vector< float > > fWireDriftPatch
 
size_t fCurrentWireIdx
 
size_t fCurrentScaledDrift
 
- Protected Attributes inherited from img::DataProviderAlg
unsigned int fCryo
 
unsigned int fTPC
 
unsigned int fPlane
 
unsigned int fNWires
 
unsigned int fNDrifts
 
unsigned int fNScaledDrifts
 
unsigned int fNCachedDrifts
 
std::vector< raw::ChannelID_t > fWireChannels
 
std::vector< std::vector< float > > fWireDriftData
 
std::vector< float > fLifetimeCorrFactors
 
EDownscaleMode fDownscaleMode
 
size_t fDriftWindow
 
bool fDownscaleFullView
 
float fDriftWindowInv
 
calo::CalorimetryAlg fCalorimetryAlg
 
geo::GeometryCore const * fGeometry
 

Detailed Description

Definition at line 19 of file PointIdAlgTrtis_tool.cc.

Constructor & Destructor Documentation

PointIdAlgTools::PointIdAlgTrtis::PointIdAlgTrtis ( fhicl::Table< Config > const &  table)
explicit

Definition at line 38 of file PointIdAlgTrtis_tool.cc.

39  : img::DataProviderAlg(table())
40  {
41  // ... Get common config vars
42  fNNetOutputs = table().NNetOutputs();
43  fPatchSizeW = table().PatchSizeW();
44  fPatchSizeD = table().PatchSizeD();
45  fCurrentWireIdx = 99999;
46  fCurrentScaledDrift = 99999;
47 
48  // ... Get "optional" config vars specific to tRTis interface
49  std::string s_cfgvr;
50  int64_t i_cfgvr;
51  bool b_cfgvr;
52  if (table().TrtisModelName(s_cfgvr)) { fTrtisModelName = s_cfgvr; }
53  else {
54  fTrtisModelName = "mycnn";
55  }
56  if (table().TrtisURL(s_cfgvr)) { fTrtisURL = s_cfgvr; }
57  else {
58  fTrtisURL = "localhost:8001";
59  }
60  if (table().TrtisVerbose(b_cfgvr)) { fTrtisVerbose = b_cfgvr; }
61  else {
62  fTrtisVerbose = false;
63  }
64  if (table().TrtisModelVersion(i_cfgvr)) { fTrtisModelVersion = i_cfgvr; }
65  else {
66  fTrtisModelVersion = -1;
67  }
68 
69  // ... Create the inference context for the specified model.
70  auto err = nic::InferGrpcContext::Create(
71  &ctx, fTrtisURL, fTrtisModelName, fTrtisModelVersion, fTrtisVerbose);
72  if (!err.IsOk()) {
73  throw cet::exception("PointIdAlgTrtis")
74  << "unable to create tRTis inference context: " << err << std::endl;
75  }
76 
77  // ... Get the specified model input
78  err = ctx->GetInput("main_input", &model_input);
79  if (!err.IsOk()) {
80  throw cet::exception("PointIdAlgTrtis") << "unable to get tRTis input: " << err << std::endl;
81  }
82 
83  mf::LogInfo("PointIdAlgTrtis") << "url: " << fTrtisURL;
84  mf::LogInfo("PointIdAlgTrtis") << "model name: " << fTrtisModelName;
85  mf::LogInfo("PointIdAlgTrtis") << "model version: " << fTrtisModelVersion;
86  mf::LogInfo("PointIdAlgTrtis") << "verbose: " << fTrtisVerbose;
87 
88  mf::LogInfo("PointIdAlgTrtis") << "tensorRT inference context created.";
89 
90  resizePatch();
91  }
std::shared_ptr< nic::InferContext::Input > model_input
std::string string
Definition: nybbler.cc:12
MaybeLogger_< ELseverityLevel::ELsev_info, false > LogInfo
std::unique_ptr< nic::InferContext > ctx
std::vector< std::string > fNNetOutputs
Definition: IPointIdAlg.h:144
void err(const char *fmt,...)
Definition: message.cpp:226
signed __int64 int64_t
Definition: stdint.h:135
cet::coded_exception< error, detail::translate > exception
Definition: exception.h:33
QTextStream & endl(QTextStream &s)

Member Function Documentation

std::vector< float > PointIdAlgTools::PointIdAlgTrtis::Run ( std::vector< std::vector< float >> const &  inp2d) const
override virtual

Implements PointIdAlgTools::IPointIdAlg.

Definition at line 95 of file PointIdAlgTrtis_tool.cc.

96  {
97  size_t nrows = inp2d.size(), ncols = inp2d.front().size();
98 
99  // ~~~~ Configure context options
100 
101  std::unique_ptr<nic::InferContext::Options> options;
102  auto err = nic::InferContext::Options::Create(&options);
103  if (!err.IsOk()) {
104  throw cet::exception("PointIdAlgTrtis")
105  << "failed initializing tRTis infer options: " << err << std::endl;
106  }
107 
108  options->SetBatchSize(1); // set batch size
109  for (const auto& output : ctx->Outputs()) { // request all output tensors
110  options->AddRawResult(output);
111  }
112 
113  err = ctx->SetRunOptions(*options);
114  if (!err.IsOk()) {
115  throw cet::exception("PointIdAlgTrtis")
116  << "unable to set tRTis infer options: " << err << std::endl;
117  }
118 
119  // ~~~~ Register the mem address of 1st byte of image and #bytes in image
120 
121  err = model_input->Reset();
122  if (!err.IsOk()) {
123  throw cet::exception("PointIdAlgTrtis")
124  << "failed resetting tRTis model input: " << err << std::endl;
125  }
126 
127  size_t sbuff_byte_size = (nrows * ncols) * sizeof(float);
128  std::vector<float> fa(sbuff_byte_size);
129 
130  // ..flatten the 2d array into contiguous 1d block
131  for (size_t ir = 0; ir < nrows; ++ir) {
132  std::copy(inp2d[ir].begin(), inp2d[ir].end(), fa.begin() + (ir * ncols));
133  }
134  err = model_input->SetRaw(reinterpret_cast<uint8_t*>(fa.data()), sbuff_byte_size);
135  if (!err.IsOk()) {
136  throw cet::exception("PointIdAlgTrtis") << "failed setting tRTis input: " << err << std::endl;
137  }
138 
139  // ~~~~ Send inference request
140 
141  std::map<std::string, std::unique_ptr<nic::InferContext::Result>> results;
142 
143  err = ctx->Run(&results);
144  if (!err.IsOk()) {
145  throw cet::exception("PointIdAlgTrtis")
146  << "failed sending tRTis synchronous infer request: " << err << std::endl;
147  }
148 
149  // ~~~~ Retrieve inference results
150 
151  std::vector<float> out;
152  std::map<std::string, std::unique_ptr<nic::InferContext::Result>>::iterator itRes =
153  results.begin();
154 
155  // .. loop over the outputs
156  while (itRes != results.end()) {
157  const std::unique_ptr<nic::InferContext::Result>& result = itRes->second;
158  const uint8_t* rbuff; // pointer to buffer holding result bytes
159  size_t rbuff_byte_size; // size of result buffer in bytes
160  result->GetRaw(0, &rbuff, &rbuff_byte_size);
161  const float* prb = reinterpret_cast<const float*>(rbuff);
162 
163  // .. loop over each class in output
164  size_t ncat = rbuff_byte_size / sizeof(float);
165  for (unsigned int j = 0; j < ncat; j++) {
166  out.push_back(prb[j]);
167  }
168  itRes++;
169  }
170 
171  return out;
172  }
end
while True: pbar.update(maxval-len(onlies[E][S])) #print iS, "/", len(onlies[E][S]) found = False for...
std::shared_ptr< nic::InferContext::Input > model_input
static QCString result
unsigned char uint8_t
Definition: stdint.h:124
std::unique_ptr< nic::InferContext > ctx
void err(const char *fmt,...)
Definition: message.cpp:226
T copy(T const &v)
decltype(auto) constexpr begin(T &&obj)
ADL-aware version of std::begin.
Definition: StdUtils.h:72
unsigned nrows(sqlite3 *db, std::string const &tablename)
Definition: helpers.cc:84
cet::coded_exception< error, detail::translate > exception
Definition: exception.h:33
QTextStream & endl(QTextStream &s)
std::vector< std::vector< float > > PointIdAlgTools::PointIdAlgTrtis::Run ( std::vector< std::vector< std::vector< float >>> const &  inps,
int  samples = -1 
) const
override virtual

Implements PointIdAlgTools::IPointIdAlg.

Definition at line 176 of file PointIdAlgTrtis_tool.cc.

177  {
178  if ((samples == 0) || inps.empty() || inps.front().empty() || inps.front().front().empty()) {
179  return std::vector<std::vector<float>>();
180  }
181 
182  if ((samples == -1) || (samples > (long long int)inps.size())) { samples = inps.size(); }
183 
184  size_t usamples = samples;
185  size_t nrows = inps.front().size(), ncols = inps.front().front().size();
186 
187  // ~~~~ Configure context options
188 
189  std::unique_ptr<nic::InferContext::Options> options;
190  auto err = nic::InferContext::Options::Create(&options);
191  if (!err.IsOk()) {
192  throw cet::exception("PointIdAlgTrtis")
193  << "failed initializing tRTis infer options: " << err << std::endl;
194  }
195 
196  options->SetBatchSize(usamples); // set batch size
197  for (const auto& output : ctx->Outputs()) { // request all output tensors
198  options->AddRawResult(output);
199  }
200 
201  err = ctx->SetRunOptions(*options);
202  if (!err.IsOk()) {
203  throw cet::exception("PointIdAlgTrtis")
204  << "unable to set tRTis inference options: " << err << std::endl;
205  }
206 
207  // ~~~~ For each sample, register the mem address of 1st byte of image and #bytes in image
208 
209  err = model_input->Reset();
210  if (!err.IsOk()) {
211  throw cet::exception("PointIdAlgTrtis")
212  << "failed resetting tRTis model input: " << err << std::endl;
213  }
214 
215  size_t sbuff_byte_size = (nrows * ncols) * sizeof(float);
216  std::vector<std::vector<float>> fa(usamples, std::vector<float>(sbuff_byte_size));
217 
218  for (size_t idx = 0; idx < usamples; ++idx) {
219  // ..first flatten the 2d array into contiguous 1d block
220  for (size_t ir = 0; ir < nrows; ++ir) {
221  std::copy(inps[idx][ir].begin(), inps[idx][ir].end(), fa[idx].begin() + (ir * ncols));
222  }
223  err = model_input->SetRaw(reinterpret_cast<uint8_t*>(fa[idx].data()), sbuff_byte_size);
224  if (!err.IsOk()) {
225  throw cet::exception("PointIdAlgTrtis")
226  << "failed setting tRTis input: " << err << std::endl;
227  }
228  }
229 
230  // ~~~~ Send inference request
231 
232  std::map<std::string, std::unique_ptr<nic::InferContext::Result>> results;
233 
234  err = ctx->Run(&results);
235  if (!err.IsOk()) {
236  throw cet::exception("PointIdAlgTrtis")
237  << "failed sending tRTis synchronous infer request: " << err << std::endl;
238  }
239 
240  // ~~~~ Retrieve inference results
241 
242  std::vector<std::vector<float>> out;
243 
244  for (unsigned int i = 0; i < usamples; i++) {
245  std::map<std::string, std::unique_ptr<nic::InferContext::Result>>::iterator itRes =
246  results.begin();
247 
248  // .. loop over the outputs
249  std::vector<float> vprb;
250  while (itRes != results.end()) {
251  const std::unique_ptr<nic::InferContext::Result>& result = itRes->second;
252  const uint8_t* rbuff; // pointer to buffer holding result bytes
253  size_t rbuff_byte_size; // size of result buffer in bytes
254  result->GetRaw(i, &rbuff, &rbuff_byte_size);
255  const float* prb = reinterpret_cast<const float*>(rbuff);
256 
257  // .. loop over each class in output
258  size_t ncat = rbuff_byte_size / sizeof(float);
259  for (unsigned int j = 0; j < ncat; j++) {
260  vprb.push_back(prb[j]);
261  }
262  itRes++;
263  }
264  out.push_back(vprb);
265  }
266 
267  return out;
268  }
end
while True: pbar.update(maxval-len(onlies[E][S])) #print iS, "/", len(onlies[E][S]) found = False for...
std::shared_ptr< nic::InferContext::Input > model_input
static QCString result
unsigned char uint8_t
Definition: stdint.h:124
std::unique_ptr< nic::InferContext > ctx
void err(const char *fmt,...)
Definition: message.cpp:226
T copy(T const &v)
decltype(auto) constexpr begin(T &&obj)
ADL-aware version of std::begin.
Definition: StdUtils.h:72
unsigned nrows(sqlite3 *db, std::string const &tablename)
Definition: helpers.cc:84
cet::coded_exception< error, detail::translate > exception
Definition: exception.h:33
QTextStream & endl(QTextStream &s)

Member Data Documentation

std::unique_ptr<nic::InferContext> PointIdAlgTools::PointIdAlgTrtis::ctx
private

Definition at line 33 of file PointIdAlgTrtis_tool.cc.

std::string PointIdAlgTools::PointIdAlgTrtis::fTrtisModelName
private

Definition at line 28 of file PointIdAlgTrtis_tool.cc.

int64_t PointIdAlgTools::PointIdAlgTrtis::fTrtisModelVersion
private

Definition at line 31 of file PointIdAlgTrtis_tool.cc.

std::string PointIdAlgTools::PointIdAlgTrtis::fTrtisURL
private

Definition at line 29 of file PointIdAlgTrtis_tool.cc.

bool PointIdAlgTools::PointIdAlgTrtis::fTrtisVerbose
private

Definition at line 30 of file PointIdAlgTrtis_tool.cc.

std::shared_ptr<nic::InferContext::Input> PointIdAlgTools::PointIdAlgTrtis::model_input
private

Definition at line 34 of file PointIdAlgTrtis_tool.cc.


The documentation for this class was generated from the following file: