Public Member Functions | Private Attributes | List of all members
PointIdAlgTools::PointIdAlgTriton Class Reference
Inheritance diagram for PointIdAlgTools::PointIdAlgTriton:
PointIdAlgTools::IPointIdAlg img::DataProviderAlg

Public Member Functions

 PointIdAlgTriton (fhicl::Table< Config > const &table)
 
std::vector< float > Run (std::vector< std::vector< float >> const &inp2d) const override
 
std::vector< std::vector< float > > Run (std::vector< std::vector< std::vector< float >>> const &inps, int samples=-1) const override
 
- Public Member Functions inherited from PointIdAlgTools::IPointIdAlg
virtual ~IPointIdAlg () noexcept=default
 
float predictIdValue (unsigned int wire, float drift, size_t outIdx=0)
 
std::vector< float > predictIdVector (unsigned int wire, float drift)
 
std::vector< std::vector< float > > predictIdVectors (const std::vector< std::pair< unsigned int, float >> &points)
 
std::vector< std::string > const & outputLabels (void) const
 
bool isInsideFiducialRegion (unsigned int wire, float drift) const
 
- Public Member Functions inherited from img::DataProviderAlg
 DataProviderAlg (const fhicl::ParameterSet &pset)
 
 DataProviderAlg (const Config &config)
 
virtual ~DataProviderAlg ()
 
bool setWireDriftData (const detinfo::DetectorClocksData &clock_data, const detinfo::DetectorPropertiesData &det_prop, const std::vector< recob::Wire > &wires, unsigned int plane, unsigned int tpc, unsigned int cryo)
 
std::vector< float > const & wireData (size_t widx) const
 
std::vector< std::vector< float > > getPatch (size_t wire, float drift, size_t patchSizeW, size_t patchSizeD) const
 
float getPixelOrZero (int wire, int drift) const
 
double getAdcSum () const
 
size_t getAdcArea () const
 
float poolMax (int wire, int drift, size_t r=0) const
 Pool max value in a patch around the wire/drift pixel. More...
 
unsigned int Cryo () const
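 (No description. The brief previously shown here, "Pool sum of pixels in a patch around the wire/drift pixel", does not match this zero-argument accessor and appears to be misattached from another declaration.)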
 
unsigned int TPC () const
 
unsigned int Plane () const
 
unsigned int NWires () const
 
unsigned int NScaledDrifts () const
 
unsigned int NCachedDrifts () const
 
unsigned int DriftWindow () const
 
float ZeroLevel () const
 Level of zero ADC after scaling. More...
 
double LifetimeCorrection (detinfo::DetectorClocksData const &clock_data, detinfo::DetectorPropertiesData const &det_prop, double tick) const
 

Private Attributes

std::string fTritonModelName
 
std::string fTritonURL
 
bool fTritonVerbose
 
std::string fTritonModelVersion
 
std::unique_ptr< nic::InferenceServerGrpcClient > triton_client
 
inference::ModelMetadataResponse triton_modmet
 
inference::ModelConfigResponse triton_modcfg
 
std::vector< int64_t > triton_inpshape
 
nic::InferOptions triton_options
 

Additional Inherited Members

- Public Types inherited from img::DataProviderAlg
enum  EDownscaleMode { kMax = 1, kMaxMean = 2, kMean = 3 }
 
- Protected Member Functions inherited from PointIdAlgTools::IPointIdAlg
bool bufferPatch (size_t wire, float drift, std::vector< std::vector< float >> &patch)
 
bool bufferPatch (size_t wire, float drift)
 
void resizePatch (void)
 
- Protected Member Functions inherited from img::DataProviderAlg
std::vector< float > downscaleMax (std::size_t dst_size, std::vector< float > const &adc, size_t tick0) const
 
std::vector< float > downscaleMaxMean (std::size_t dst_size, std::vector< float > const &adc, size_t tick0) const
 
std::vector< float > downscaleMean (std::size_t dst_size, std::vector< float > const &adc, size_t tick0) const
 
std::vector< float > downscale (std::size_t dst_size, std::vector< float > const &adc, size_t tick0) const
 
size_t getDriftIndex (float drift) const
 
std::optional< std::vector< float > > setWireData (std::vector< float > const &adc, size_t wireIdx) const
 
bool patchFromDownsampledView (size_t wire, float drift, size_t size_w, size_t size_d, std::vector< std::vector< float >> &patch) const
 
bool patchFromOriginalView (size_t wire, float drift, size_t size_w, size_t size_d, std::vector< std::vector< float >> &patch) const
 
virtual DataProviderAlgView resizeView (detinfo::DetectorClocksData const &clock_data, detinfo::DetectorPropertiesData const &det_prop, size_t wires, size_t drifts)
 
- Protected Attributes inherited from PointIdAlgTools::IPointIdAlg
std::vector< std::string > fNNetOutputs
 
size_t fPatchSizeW
 
size_t fPatchSizeD
 
std::vector< std::vector< float > > fWireDriftPatch
 
size_t fCurrentWireIdx
 
size_t fCurrentScaledDrift
 
- Protected Attributes inherited from img::DataProviderAlg
DataProviderAlgView fAlgView
 
EDownscaleMode fDownscaleMode
 
size_t fDriftWindow
 
bool fDownscaleFullView
 
float fDriftWindowInv
 
calo::CalorimetryAlg fCalorimetryAlg
 
geo::GeometryCore const * fGeometry
 

Detailed Description

Definition at line 17 of file PointIdAlgTriton_tool.cc.

Constructor & Destructor Documentation

PointIdAlgTools::PointIdAlgTriton::PointIdAlgTriton ( fhicl::Table< Config > const &  table)
explicit

Definition at line 41 of file PointIdAlgTriton_tool.cc.

42  : img::DataProviderAlg(table()), triton_options("")
43  {
44  // ... Get common config vars
45  fNNetOutputs = table().NNetOutputs();
46  fPatchSizeW = table().PatchSizeW();
47  fPatchSizeD = table().PatchSizeD();
48  fCurrentWireIdx = 99999;
49  fCurrentScaledDrift = 99999;
50 
51  // ... Get "optional" config vars specific to Triton interface
52  fTritonModelName = table().TritonModelName();
53  fTritonURL = table().TritonURL();
54  fTritonVerbose = table().TritonVerbose();
55  fTritonModelVersion = table().TritonModelVersion();
56 
57  // ... Create the Triton inference client
58  auto err = nic::InferenceServerGrpcClient::Create(&triton_client, fTritonURL, fTritonVerbose);
59  if (!err.IsOk()) {
60  throw cet::exception("PointIdAlgTriton")
61  << "error: unable to create client for inference: " << err << std::endl;
62  }
63 
64  // ... Get the model metadata and config information
65  err = triton_client->ModelMetadata(&triton_modmet, fTritonModelName, fTritonModelVersion);
66  if (!err.IsOk()) {
67  throw cet::exception("PointIdAlgTriton")
68  << "error: failed to get model metadata: " << err << std::endl;
69  }
70  err = triton_client->ModelConfig(&triton_modcfg, fTritonModelName, fTritonModelVersion);
71  if (!err.IsOk()) {
72  throw cet::exception("PointIdAlgTriton")
73  << "error: failed to get model config: " << err << std::endl;
74  }
75 
76  // ... Set up shape vector needed when creating inference input
77  triton_inpshape.push_back(1); // initialize batch_size to 1
78  triton_inpshape.push_back(triton_modmet.inputs(0).shape(1));
79  triton_inpshape.push_back(triton_modmet.inputs(0).shape(2));
80  triton_inpshape.push_back(triton_modmet.inputs(0).shape(3));
81 
82  // ... Set up Triton inference client options
83  triton_options.model_name_ = fTritonModelName;
84  triton_options.model_version_ = fTritonModelVersion;
85 
86  mf::LogInfo("PointIdAlgTriton") << "url: " << fTritonURL;
87  mf::LogInfo("PointIdAlgTriton") << "model name: " << fTritonModelName;
88  mf::LogInfo("PointIdAlgTriton") << "model version: " << fTritonModelVersion;
89  mf::LogInfo("PointIdAlgTriton") << "verbose: " << fTritonVerbose;
90 
91  mf::LogInfo("PointIdAlgTriton") << "tensorRT inference context created.";
92 
93  resizePatch();
94  }
MaybeLogger_< ELseverityLevel::ELsev_info, false > LogInfo
inference::ModelMetadataResponse triton_modmet
std::vector< std::string > fNNetOutputs
Definition: IPointIdAlg.h:152
inference::ModelConfigResponse triton_modcfg
void err(const char *fmt,...)
Definition: message.cpp:226
std::unique_ptr< nic::InferenceServerGrpcClient > triton_client
cet::coded_exception< error, detail::translate > exception
Definition: exception.h:33
QTextStream & endl(QTextStream &s)

Member Function Documentation

std::vector< float > PointIdAlgTools::PointIdAlgTriton::Run ( std::vector< std::vector< float >> const &  inp2d) const
override, virtual

Implements PointIdAlgTools::IPointIdAlg.

Definition at line 98 of file PointIdAlgTriton_tool.cc.

99  {
100  size_t nrows = inp2d.size(), ncols = inp2d.front().size();
101 
102  triton_inpshape.at(0) = 1; // set batch size
103 
104  // ~~~~ Initialize the inputs
105 
106  nic::InferInput* triton_input;
107  auto err = nic::InferInput::Create(
108  &triton_input, triton_modmet.inputs(0).name(), triton_inpshape, triton_modmet.inputs(0).datatype() );
109  if (!err.IsOk()) {
110  throw cet::exception("PointIdAlgTriton")
111  << "unable to get input: " << err << std::endl;
112  }
113  std::shared_ptr<nic::InferInput> triton_input_ptr(triton_input);
114  std::vector<nic::InferInput*> triton_inputs = {triton_input_ptr.get()};
115 
116  // ~~~~ Register the mem address of 1st byte of image and #bytes in image
117 
118  err = triton_input_ptr->Reset();
119  if (!err.IsOk()) {
120  throw cet::exception("PointIdAlgTriton")
121  << "failed resetting Triton model input: " << err << std::endl;
122  }
123 
124  size_t sbuff_byte_size = (nrows * ncols) * sizeof(float);
125  std::vector<float> fa(sbuff_byte_size);
126 
127  // ..flatten the 2d array into contiguous 1d block
128  for (size_t ir = 0; ir < nrows; ++ir) {
129  std::copy(inp2d[ir].begin(), inp2d[ir].end(), fa.begin() + (ir * ncols));
130  }
131  err = triton_input_ptr->AppendRaw(reinterpret_cast<uint8_t*>(fa.data()), sbuff_byte_size);
132  if (!err.IsOk()) {
133  throw cet::exception("PointIdAlgTriton") << "failed setting Triton input: " << err << std::endl;
134  }
135 
136  // ~~~~ Send inference request
137 
138  nic::InferResult* results;
139 
140  err = triton_client->Infer(&results, triton_options, triton_inputs);
141  if (!err.IsOk()) {
142  throw cet::exception("PointIdAlgTriton")
143  << "failed sending Triton synchronous infer request: " << err << std::endl;
144  }
145  std::shared_ptr<nic::InferResult> results_ptr;
146  results_ptr.reset(results);
147 
148  // ~~~~ Retrieve inference results
149 
150  std::vector<float> out;
151 
152  const float *prb0;
153  size_t rbuff0_byte_size; // size of result buffer in bytes
154  results_ptr->RawData(triton_modmet.outputs(0).name(), (const uint8_t**)&prb0, &rbuff0_byte_size);
155  size_t ncat0 = rbuff0_byte_size/sizeof(float);
156 
157  const float *prb1;
158  size_t rbuff1_byte_size; // size of result buffer in bytes
159  results_ptr->RawData(triton_modmet.outputs(1).name(), (const uint8_t**)&prb1, &rbuff1_byte_size);
160  size_t ncat1 = rbuff1_byte_size/sizeof(float);
161 
162  for(unsigned j = 0; j < ncat0; j++) out.push_back(*(prb0 + j ));
163  for(unsigned j = 0; j < ncat1; j++) out.push_back(*(prb1 + j ));
164 
165  return out;
166  }
end
while True: pbar.update(maxval-len(onlies[E][S])) #print iS, "/", len(onlies[E][S]) found = False for...
inference::ModelMetadataResponse triton_modmet
void err(const char *fmt,...)
Definition: message.cpp:226
std::unique_ptr< nic::InferenceServerGrpcClient > triton_client
T copy(T const &v)
decltype(auto) constexpr begin(T &&obj)
ADL-aware version of std::begin.
Definition: StdUtils.h:72
unsigned nrows(sqlite3 *db, std::string const &tablename)
Definition: helpers.cc:82
cet::coded_exception< error, detail::translate > exception
Definition: exception.h:33
QTextStream & endl(QTextStream &s)
std::vector< std::vector< float > > PointIdAlgTools::PointIdAlgTriton::Run ( std::vector< std::vector< std::vector< float >>> const &  inps,
int  samples = -1 
) const
override, virtual

Implements PointIdAlgTools::IPointIdAlg.

Definition at line 170 of file PointIdAlgTriton_tool.cc.

171  {
172  if ((samples == 0) || inps.empty() || inps.front().empty() || inps.front().front().empty()) {
173  return std::vector<std::vector<float>>();
174  }
175 
176  if ((samples == -1) || (samples > (long long int)inps.size())) { samples = inps.size(); }
177 
178  size_t usamples = samples;
179  size_t nrows = inps.front().size(), ncols = inps.front().front().size();
180 
181  triton_inpshape.at(0) = usamples; // set batch size
182 
183  // ~~~~ Initialize the inputs
184 
185  nic::InferInput* triton_input;
186  auto err = nic::InferInput::Create(
187  &triton_input, triton_modmet.inputs(0).name(), triton_inpshape, triton_modmet.inputs(0).datatype() );
188  if (!err.IsOk()) {
189  throw cet::exception("PointIdAlgTriton")
190  << "unable to get input: " << err << std::endl;
191  }
192  std::shared_ptr<nic::InferInput> triton_input_ptr(triton_input);
193  std::vector<nic::InferInput*> triton_inputs = {triton_input_ptr.get()};
194 
195  // ~~~~ For each sample, register the mem address of 1st byte of image and #bytes in image
196  err = triton_input_ptr->Reset();
197  if (!err.IsOk()) {
198  throw cet::exception("PointIdAlgTriton")
199  << "failed resetting Triton model input: " << err << std::endl;
200  }
201 
202  size_t sbuff_byte_size = (nrows * ncols) * sizeof(float);
203  std::vector<std::vector<float>> fa(usamples, std::vector<float>(sbuff_byte_size));
204 
205  for (size_t idx = 0; idx < usamples; ++idx) {
206  // ..first flatten the 2d array into contiguous 1d block
207  for (size_t ir = 0; ir < nrows; ++ir) {
208  std::copy(inps[idx][ir].begin(), inps[idx][ir].end(), fa[idx].begin() + (ir * ncols));
209  }
210  err = triton_input_ptr->AppendRaw(reinterpret_cast<uint8_t*>(fa[idx].data()), sbuff_byte_size);
211  if (!err.IsOk()) {
212  throw cet::exception("PointIdAlgTriton")
213  << "failed setting Triton input: " << err << std::endl;
214  }
215  }
216 
217  // ~~~~ Send inference request
218 
219  nic::InferResult* results;
220 
221  err = triton_client->Infer(&results, triton_options, triton_inputs);
222  if (!err.IsOk()) {
223  throw cet::exception("PointIdAlgTriton")
224  << "failed sending Triton synchronous infer request: " << err << std::endl;
225  }
226  std::shared_ptr<nic::InferResult> results_ptr;
227  results_ptr.reset(results);
228 
229  // ~~~~ Retrieve inference results
230 
231  std::vector<std::vector<float>> out;
232 
233  const float *prb0;
234  size_t rbuff0_byte_size; // size of result buffer in bytes
235  results_ptr->RawData(triton_modmet.outputs(0).name(), (const uint8_t**)&prb0, &rbuff0_byte_size);
236  size_t ncat0 = rbuff0_byte_size/(usamples*sizeof(float));
237 
238  const float *prb1;
239  size_t rbuff1_byte_size; // size of result buffer in bytes
240  results_ptr->RawData(triton_modmet.outputs(1).name(), (const uint8_t**)&prb1, &rbuff1_byte_size);
241  size_t ncat1 = rbuff1_byte_size/(usamples*sizeof(float));
242 
243  for(unsigned i = 0; i < usamples; i++) {
244  std::vector<float> vprb;
245  for(unsigned j = 0; j < ncat0; j++) vprb.push_back(*(prb0 + i*ncat0 + j ));
246  for(unsigned j = 0; j < ncat1; j++) vprb.push_back(*(prb1 + i*ncat1 + j ));
247  out.push_back(vprb);
248  }
249 
250  return out;
251  }
end
while True: pbar.update(maxval-len(onlies[E][S])) #print iS, "/", len(onlies[E][S]) found = False for...
inference::ModelMetadataResponse triton_modmet
void err(const char *fmt,...)
Definition: message.cpp:226
std::unique_ptr< nic::InferenceServerGrpcClient > triton_client
T copy(T const &v)
decltype(auto) constexpr begin(T &&obj)
ADL-aware version of std::begin.
Definition: StdUtils.h:72
unsigned nrows(sqlite3 *db, std::string const &tablename)
Definition: helpers.cc:82
cet::coded_exception< error, detail::translate > exception
Definition: exception.h:33
QTextStream & endl(QTextStream &s)

Member Data Documentation

std::string PointIdAlgTools::PointIdAlgTriton::fTritonModelName
private

Definition at line 26 of file PointIdAlgTriton_tool.cc.

std::string PointIdAlgTools::PointIdAlgTriton::fTritonModelVersion
private

Definition at line 29 of file PointIdAlgTriton_tool.cc.

std::string PointIdAlgTools::PointIdAlgTriton::fTritonURL
private

Definition at line 27 of file PointIdAlgTriton_tool.cc.

bool PointIdAlgTools::PointIdAlgTriton::fTritonVerbose
private

Definition at line 28 of file PointIdAlgTriton_tool.cc.

std::unique_ptr<nic::InferenceServerGrpcClient> PointIdAlgTools::PointIdAlgTriton::triton_client
private

Definition at line 31 of file PointIdAlgTriton_tool.cc.

std::vector<int64_t> PointIdAlgTools::PointIdAlgTriton::triton_inpshape
mutable, private

Definition at line 34 of file PointIdAlgTriton_tool.cc.

inference::ModelConfigResponse PointIdAlgTools::PointIdAlgTriton::triton_modcfg
private

Definition at line 33 of file PointIdAlgTriton_tool.cc.

inference::ModelMetadataResponse PointIdAlgTools::PointIdAlgTriton::triton_modmet
private

Definition at line 32 of file PointIdAlgTriton_tool.cc.

nic::InferOptions PointIdAlgTools::PointIdAlgTriton::triton_options
private

Definition at line 35 of file PointIdAlgTriton_tool.cc.


The documentation for this class was generated from the following file: