PointIdAlgTriton_tool.cc
Go to the documentation of this file.
1 ////////////////////////////////////////////////////////////////////////////////////////////////////
2 // Class: PointIdAlgTriton_tool
3 // Authors: M.Wang, FNAL, 2021: Nvidia Triton inf client
4 ////////////////////////////////////////////////////////////////////////////////////////////////////
5 
8 
9 // Nvidia Triton inference server client includes
10 #include "grpc_client.h"
11 
12 namespace ni = nvidia::inferenceserver;
13 namespace nic = nvidia::inferenceserver::client;
14 
15 namespace PointIdAlgTools {
16 
// PointIdAlgTriton: IPointIdAlg implementation that sends input patches to a model
// hosted on an Nvidia Triton inference server through the gRPC client (grpc_client.h).
17  class PointIdAlgTriton : public IPointIdAlg {
18  public:
19  explicit PointIdAlgTriton(fhicl::Table<Config> const& table); // reads patch + Triton settings, creates the gRPC client, sets up the input shape
20 
21  std::vector<float> Run(std::vector<std::vector<float>> const& inp2d) const override; // one patch sent as a batch of 1; returns outputs 0 and 1 concatenated
22  std::vector<std::vector<float>> Run(std::vector<std::vector<std::vector<float>>> const& inps,
23  int samples = -1) const override; // batch of patches (samples == -1: all of inps); one result vector per sample
// NOTE(review): a default argument on a virtual override is chosen by the caller's static
// type, so keep samples = -1 identical to the IPointIdAlg declaration.
24 
25  private:
// NOTE(review): original lines 26-29 are missing from this listing; presumably they declare the
// fTriton* settings assigned in the constructor (model name, URL, verbose, version) — confirm.
30 
31  std::unique_ptr<nic::InferenceServerGrpcClient> triton_client; // gRPC client created from fTritonURL in the constructor
32  inference::ModelMetadataResponse triton_modmet; // model metadata: input/output tensor names, input datatype and shape
33  inference::ModelConfigResponse triton_modcfg; // model config response; not read by the code visible here
34  mutable std::vector<int64_t> triton_inpshape; // input tensor shape {batch, dim1, dim2, dim3}; element 0 is the batch size
35  nic::InferOptions triton_options; // carries model name/version; passed to every Infer() call
36 
37 
38  };
39 
// PointIdAlgTriton constructor: initializes the img::DataProviderAlg base from table(), reads
// the patch and Triton settings, creates the Triton gRPC client, builds triton_inpshape from
// input 0 of the model metadata, fills triton_options, and finally calls resizePatch().
// Throws cet::exception when client creation or the metadata/config checks report an error.
// NOTE(review): original line 41 (the signature) is missing from this listing; presumably
// PointIdAlgTriton::PointIdAlgTriton(fhicl::Table<Config> const& table) — confirm.
40  // ------------------------------------------------------
42  : img::DataProviderAlg(table()), triton_options("")
43  {
44  // ... Get common config vars
45  fNNetOutputs = table().NNetOutputs();
46  fPatchSizeW = table().PatchSizeW();
47  fPatchSizeD = table().PatchSizeD();
48  fCurrentWireIdx = 99999; // NOTE(review): presumably a "no patch cached yet" sentinel — confirm
49  fCurrentScaledDrift = 99999;
50 
51  // ... Get "optional" config vars specific to Triton interface
52  fTritonModelName = table().TritonModelName();
53  fTritonURL = table().TritonURL();
54  fTritonVerbose = table().TritonVerbose();
55  fTritonModelVersion = table().TritonModelVersion();
56 
57  // ... Create the Triton inference client
58  auto err = nic::InferenceServerGrpcClient::Create(&triton_client, fTritonURL, fTritonVerbose);
59  if (!err.IsOk()) {
60  throw cet::exception("PointIdAlgTriton")
61  << "error: unable to create client for inference: " << err << std::endl;
62  }
63 
64  // ... Get the model metadata and config information
// NOTE(review): original lines 65 and 70 are missing from this listing. As shown, both checks
// below re-test the client-creation err; presumably the missing lines query the model
// metadata into triton_modmet and the model config into triton_modcfg — confirm.
66  if (!err.IsOk()) {
67  throw cet::exception("PointIdAlgTriton")
68  << "error: failed to get model metadata: " << err << std::endl;
69  }
71  if (!err.IsOk()) {
72  throw cet::exception("PointIdAlgTriton")
73  << "error: failed to get model config: " << err << std::endl;
74  }
75 
76  // ... Set up shape vector needed when creating inference input
// Dimension 0 of input 0 is replaced by our own batch size; dims 1..3 are copied from the
// metadata. NOTE(review): assumes input 0 has rank >= 4 — no check is made here.
77  triton_inpshape.push_back(1); // initialize batch_size to 1
78  triton_inpshape.push_back(triton_modmet.inputs(0).shape(1));
79  triton_inpshape.push_back(triton_modmet.inputs(0).shape(2));
80  triton_inpshape.push_back(triton_modmet.inputs(0).shape(3));
81 
82  // ... Set up Triton inference client options
83  triton_options.model_name_ = fTritonModelName;
84  triton_options.model_version_ = fTritonModelVersion;
85 
86  mf::LogInfo("PointIdAlgTriton") << "url: " << fTritonURL;
87  mf::LogInfo("PointIdAlgTriton") << "model name: " << fTritonModelName;
88  mf::LogInfo("PointIdAlgTriton") << "model version: " << fTritonModelVersion;
89  mf::LogInfo("PointIdAlgTriton") << "verbose: " << fTritonVerbose;
90 
// NOTE(review): this message mentions tensorRT although what was created is a Triton gRPC client.
91  mf::LogInfo("PointIdAlgTriton") << "tensorRT inference context created.";
92 
// resizePatch() is defined outside this listing (presumably sizes the patch buffer from
// fPatchSizeW/fPatchSizeD — confirm).
93  resizePatch();
94  }
95 
96  // ------------------------------------------------------
97  std::vector<float>
98  PointIdAlgTriton::Run(std::vector<std::vector<float>> const& inp2d) const
99  {
100  size_t nrows = inp2d.size(), ncols = inp2d.front().size();
101 
102  triton_inpshape.at(0) = 1; // set batch size
103 
104  // ~~~~ Initialize the inputs
105 
106  nic::InferInput* triton_input;
107  auto err = nic::InferInput::Create(
108  &triton_input, triton_modmet.inputs(0).name(), triton_inpshape, triton_modmet.inputs(0).datatype() );
109  if (!err.IsOk()) {
110  throw cet::exception("PointIdAlgTriton")
111  << "unable to get input: " << err << std::endl;
112  }
113  std::shared_ptr<nic::InferInput> triton_input_ptr(triton_input);
114  std::vector<nic::InferInput*> triton_inputs = {triton_input_ptr.get()};
115 
116  // ~~~~ Register the mem address of 1st byte of image and #bytes in image
117 
118  err = triton_input_ptr->Reset();
119  if (!err.IsOk()) {
120  throw cet::exception("PointIdAlgTriton")
121  << "failed resetting Triton model input: " << err << std::endl;
122  }
123 
124  size_t sbuff_byte_size = (nrows * ncols) * sizeof(float);
125  std::vector<float> fa(sbuff_byte_size);
126 
127  // ..flatten the 2d array into contiguous 1d block
128  for (size_t ir = 0; ir < nrows; ++ir) {
129  std::copy(inp2d[ir].begin(), inp2d[ir].end(), fa.begin() + (ir * ncols));
130  }
131  err = triton_input_ptr->AppendRaw(reinterpret_cast<uint8_t*>(fa.data()), sbuff_byte_size);
132  if (!err.IsOk()) {
133  throw cet::exception("PointIdAlgTriton") << "failed setting Triton input: " << err << std::endl;
134  }
135 
136  // ~~~~ Send inference request
137 
138  nic::InferResult* results;
139 
140  err = triton_client->Infer(&results, triton_options, triton_inputs);
141  if (!err.IsOk()) {
142  throw cet::exception("PointIdAlgTriton")
143  << "failed sending Triton synchronous infer request: " << err << std::endl;
144  }
145  std::shared_ptr<nic::InferResult> results_ptr;
146  results_ptr.reset(results);
147 
148  // ~~~~ Retrieve inference results
149 
150  std::vector<float> out;
151 
152  const float *prb0;
153  size_t rbuff0_byte_size; // size of result buffer in bytes
154  results_ptr->RawData(triton_modmet.outputs(0).name(), (const uint8_t**)&prb0, &rbuff0_byte_size);
155  size_t ncat0 = rbuff0_byte_size/sizeof(float);
156 
157  const float *prb1;
158  size_t rbuff1_byte_size; // size of result buffer in bytes
159  results_ptr->RawData(triton_modmet.outputs(1).name(), (const uint8_t**)&prb1, &rbuff1_byte_size);
160  size_t ncat1 = rbuff1_byte_size/sizeof(float);
161 
162  for(unsigned j = 0; j < ncat0; j++) out.push_back(*(prb0 + j ));
163  for(unsigned j = 0; j < ncat1; j++) out.push_back(*(prb1 + j ));
164 
165  return out;
166  }
167 
168  // ------------------------------------------------------
169  std::vector<std::vector<float>>
170  PointIdAlgTriton::Run(std::vector<std::vector<std::vector<float>>> const& inps, int samples) const
171  {
172  if ((samples == 0) || inps.empty() || inps.front().empty() || inps.front().front().empty()) {
173  return std::vector<std::vector<float>>();
174  }
175 
176  if ((samples == -1) || (samples > (long long int)inps.size())) { samples = inps.size(); }
177 
178  size_t usamples = samples;
179  size_t nrows = inps.front().size(), ncols = inps.front().front().size();
180 
181  triton_inpshape.at(0) = usamples; // set batch size
182 
183  // ~~~~ Initialize the inputs
184 
185  nic::InferInput* triton_input;
186  auto err = nic::InferInput::Create(
187  &triton_input, triton_modmet.inputs(0).name(), triton_inpshape, triton_modmet.inputs(0).datatype() );
188  if (!err.IsOk()) {
189  throw cet::exception("PointIdAlgTriton")
190  << "unable to get input: " << err << std::endl;
191  }
192  std::shared_ptr<nic::InferInput> triton_input_ptr(triton_input);
193  std::vector<nic::InferInput*> triton_inputs = {triton_input_ptr.get()};
194 
195  // ~~~~ For each sample, register the mem address of 1st byte of image and #bytes in image
196  err = triton_input_ptr->Reset();
197  if (!err.IsOk()) {
198  throw cet::exception("PointIdAlgTriton")
199  << "failed resetting Triton model input: " << err << std::endl;
200  }
201 
202  size_t sbuff_byte_size = (nrows * ncols) * sizeof(float);
203  std::vector<std::vector<float>> fa(usamples, std::vector<float>(sbuff_byte_size));
204 
205  for (size_t idx = 0; idx < usamples; ++idx) {
206  // ..first flatten the 2d array into contiguous 1d block
207  for (size_t ir = 0; ir < nrows; ++ir) {
208  std::copy(inps[idx][ir].begin(), inps[idx][ir].end(), fa[idx].begin() + (ir * ncols));
209  }
210  err = triton_input_ptr->AppendRaw(reinterpret_cast<uint8_t*>(fa[idx].data()), sbuff_byte_size);
211  if (!err.IsOk()) {
212  throw cet::exception("PointIdAlgTriton")
213  << "failed setting Triton input: " << err << std::endl;
214  }
215  }
216 
217  // ~~~~ Send inference request
218 
219  nic::InferResult* results;
220 
221  err = triton_client->Infer(&results, triton_options, triton_inputs);
222  if (!err.IsOk()) {
223  throw cet::exception("PointIdAlgTriton")
224  << "failed sending Triton synchronous infer request: " << err << std::endl;
225  }
226  std::shared_ptr<nic::InferResult> results_ptr;
227  results_ptr.reset(results);
228 
229  // ~~~~ Retrieve inference results
230 
231  std::vector<std::vector<float>> out;
232 
233  const float *prb0;
234  size_t rbuff0_byte_size; // size of result buffer in bytes
235  results_ptr->RawData(triton_modmet.outputs(0).name(), (const uint8_t**)&prb0, &rbuff0_byte_size);
236  size_t ncat0 = rbuff0_byte_size/(usamples*sizeof(float));
237 
238  const float *prb1;
239  size_t rbuff1_byte_size; // size of result buffer in bytes
240  results_ptr->RawData(triton_modmet.outputs(1).name(), (const uint8_t**)&prb1, &rbuff1_byte_size);
241  size_t ncat1 = rbuff1_byte_size/(usamples*sizeof(float));
242 
243  for(unsigned i = 0; i < usamples; i++) {
244  std::vector<float> vprb;
245  for(unsigned j = 0; j < ncat0; j++) vprb.push_back(*(prb0 + i*ncat0 + j ));
246  for(unsigned j = 0; j < ncat1; j++) vprb.push_back(*(prb1 + i*ncat1 + j ));
247  out.push_back(vprb);
248  }
249 
250  return out;
251  }
252 
253 }
end
while True: pbar.update(maxval-len(onlies[E][S])) #print iS, "/", len(onlies[E][S]) found = False for...
PointIdAlgTriton(fhicl::Table< Config > const &table)
#define DEFINE_ART_CLASS_TOOL(tool)
Definition: ToolMacros.h:42
std::string string
Definition: nybbler.cc:12
MaybeLogger_< ELseverityLevel::ELsev_info, false > LogInfo
struct vector vector
DataProviderAlg(const fhicl::ParameterSet &pset)
std::vector< float > Run(std::vector< std::vector< float >> const &inp2d) const override
inference::ModelMetadataResponse triton_modmet
std::vector< std::string > fNNetOutputs
Definition: IPointIdAlg.h:152
inference::ModelConfigResponse triton_modcfg
void err(const char *fmt,...)
Definition: message.cpp:226
std::unique_ptr< nic::InferenceServerGrpcClient > triton_client
T copy(T const &v)
decltype(auto) constexpr begin(T &&obj)
ADL-aware version of std::begin.
Definition: StdUtils.h:72
unsigned nrows(sqlite3 *db, std::string const &tablename)
Definition: helpers.cc:82
cet::coded_exception< error, detail::translate > exception
Definition: exception.h:33
QTextStream & endl(QTextStream &s)