TritonClient.cc

#include "grpc_client.h"
#include "grpc_service.pb.h"

#include <string>
#include <cmath>
#include <chrono>
#include <exception>
#include <sstream>
#include <utility>
#include <tuple>

namespace ni = nvidia::inferenceserver;
namespace nic = ni::client;

//based on https://github.com/triton-inference-server/server/blob/v2.3.0/src/clients/c++/examples/simple_grpc_async_infer_client.cc
//and https://github.com/triton-inference-server/server/blob/v2.3.0/src/clients/c++/perf_client/perf_client.cc

namespace lartriton {

TritonClient::TritonClient(const fhicl::ParameterSet& params)
    : allowedTries_(params.get<unsigned>("allowedTries", 0)),
      serverURL_(params.get<std::string>("serverURL")),
      verbose_(params.get<bool>("verbose")),
      options_(params.get<std::string>("modelName")) {
  //get appropriate server for this model
  if (verbose_)
    MF_LOG_INFO("TritonClient") << "Using server: " << serverURL_;

  //connect to the server
  //TODO: add SSL options
  triton_utils::throwIfError(nic::InferenceServerGrpcClient::Create(&client_, serverURL_, false),
                             "TritonClient(): unable to create inference context");

  //set options
  options_.model_version_ = params.get<std::string>("modelVersion");
  //convert seconds to microseconds
  options_.client_timeout_ = params.get<unsigned>("timeout") * 1e6;

  //config needed for batch size
  inference::ModelConfigResponse modelConfigResponse;
  triton_utils::throwIfError(client_->ModelConfig(&modelConfigResponse, options_.model_name_, options_.model_version_),
                             "TritonClient(): unable to get model config");
  inference::ModelConfig modelConfig(modelConfigResponse.config());

  //check batch size limitations
  //triton uses max batch size = 0 to denote a model that does not support batching
  //but for models that do support batching, a given event may set batch size 0 to indicate no valid input is present
  //so set the local max to 1 and keep track of the "no batch" case
  maxBatchSize_ = modelConfig.max_batch_size();
  noBatch_ = maxBatchSize_ == 0;
  maxBatchSize_ = std::max(1u, maxBatchSize_);
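  //for example (illustrative values only, not taken from any particular model config):
  //  max_batch_size: 0  ->  noBatch_ = true,  maxBatchSize_ = 1
  //  max_batch_size: 8  ->  noBatch_ = false, maxBatchSize_ = 8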

  //get model info
  inference::ModelMetadataResponse modelMetadata;
  triton_utils::throwIfError(client_->ModelMetadata(&modelMetadata, options_.model_name_, options_.model_version_),
                             "TritonClient(): unable to get model metadata");

  //get input and output (which know their sizes)
  const auto& nicInputs = modelMetadata.inputs();
  const auto& nicOutputs = modelMetadata.outputs();

  //report all model errors at once
  std::ostringstream msg;
  std::string msg_str;

  //currently no use case is foreseen for a model with zero inputs or outputs
  if (nicInputs.empty())
    msg << "Model on server appears malformed (zero inputs)\n";

  if (nicOutputs.empty())
    msg << "Model on server appears malformed (zero outputs)\n";

  //stop if errors
  msg_str = msg.str();
  if (!msg_str.empty())
    throw cet::exception("ModelErrors") << msg_str;

  //setup input map
  std::ostringstream io_msg;
  if (verbose_)
    io_msg << "Model inputs: "
           << "\n";
  inputsTriton_.reserve(nicInputs.size());
  for (const auto& nicInput : nicInputs) {
    const auto& iname = nicInput.name();
    auto [curr_itr, success] = input_.try_emplace(iname, iname, nicInput, noBatch_);
    auto& curr_input = curr_itr->second;
    inputsTriton_.push_back(curr_input.data());
    if (verbose_) {
      io_msg << "  " << iname << " (" << curr_input.dname() << ", " << curr_input.byteSize()
             << " b) : " << triton_utils::printColl(curr_input.shape()) << "\n";
    }
  }

  //allow selecting only some outputs from server
  const auto& v_outputs = params.get<std::vector<std::string>>("outputs");
  std::unordered_set<std::string> s_outputs(v_outputs.begin(), v_outputs.end());

  //setup output map
  if (verbose_)
    io_msg << "Model outputs: "
           << "\n";
  outputsTriton_.reserve(nicOutputs.size());
  for (const auto& nicOutput : nicOutputs) {
    const auto& oname = nicOutput.name();
    if (!s_outputs.empty() and s_outputs.find(oname) == s_outputs.end())
      continue;
    auto [curr_itr, success] = output_.try_emplace(oname, oname, nicOutput, noBatch_);
    auto& curr_output = curr_itr->second;
    outputsTriton_.push_back(curr_output.data());
    if (verbose_) {
      io_msg << "  " << oname << " (" << curr_output.dname() << ", " << curr_output.byteSize()
             << " b) : " << triton_utils::printColl(curr_output.shape()) << "\n";
    }
    if (!s_outputs.empty())
      s_outputs.erase(oname);
  }

  //check if any requested outputs were not available
  if (!s_outputs.empty())
    throw cet::exception("MissingOutput")
        << "Some requested outputs were not available on the server: " << triton_utils::printColl(s_outputs);

  //propagate batch size to inputs and outputs
  setBatchSize(1);

  //print model info
  if (verbose_) {
    std::ostringstream model_msg;
    model_msg << "Model name: " << options_.model_name_ << "\n"
              << "Model version: " << options_.model_version_ << "\n"
              << "Model max batch size: " << (noBatch_ ? 0 : maxBatchSize_) << "\n";
    MF_LOG_INFO("TritonClient") << model_msg.str() << io_msg.str();
  }
}
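
//The constructor above reads its settings from a fhicl::ParameterSet. A minimal
//configuration sketch (parameter names taken from the get<>() calls above; the
//values shown are illustrative only, not defaults shipped with the package):
//
//  TritonClient: {
//    serverURL:    "localhost:8001"   # gRPC endpoint of the Triton server
//    modelName:    "mymodel"          # model to query
//    modelVersion: ""                 # empty string lets the server choose the version
//    timeout:      60                 # seconds, converted to microseconds above
//    verbose:      false
//    allowedTries: 1                  # maximum retry attempts used by finish() on failure
//    outputs:      []                 # empty list keeps all outputs the model provides
//  }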

bool TritonClient::setBatchSize(unsigned bsize) {
  if (bsize > maxBatchSize_) {
    MF_LOG_WARNING("TritonClient") << "Requested batch size " << bsize << " exceeds server-specified max batch size "
                                   << maxBatchSize_ << ". Batch size will remain as " << batchSize_;
    return false;
  }
  batchSize_ = bsize;
  //set for input and output
  for (auto& element : input_) {
    element.second.setBatchSize(bsize);
  }
  for (auto& element : output_) {
    element.second.setBatchSize(bsize);
  }
  return true;
}

void TritonClient::reset() {
  for (auto& element : input_) {
    element.second.reset();
  }
  for (auto& element : output_) {
    element.second.reset();
  }
}

bool TritonClient::getResults(std::shared_ptr<nic::InferResult> results) {
  for (auto& [oname, output] : output_) {
    //set shape here before output becomes const
    if (output.variableDims()) {
      std::vector<int64_t> tmp_shape;
      bool status = triton_utils::warnIfError(results->Shape(oname, &tmp_shape),
                                              "getResults(): unable to get output shape for " + oname);
      if (!status)
        return status;
      output.setShape(tmp_shape, false);
    }
    //extend lifetime
    output.setResult(results);
  }

  return true;
}

void TritonClient::start() {
  tries_ = 0;
}

//default case for sync and pseudo async
void TritonClient::evaluate() {
  //in case there is nothing to process
  if (batchSize_ == 0) {
    finish(true);
    return;
  }

  //get the status of the server prior to the request being made
  const auto& start_status = getServerSideStatus();

  //blocking call
  auto t1 = std::chrono::steady_clock::now();
  nic::InferResult* results;
  bool status = triton_utils::warnIfError(client_->Infer(&results, options_, inputsTriton_, outputsTriton_),
                                          "evaluate(): unable to run and/or get result");
  if (!status) {
    finish(false);
    return;
  }

  auto t2 = std::chrono::steady_clock::now();
  MF_LOG_DEBUG("TritonClient") << "Remote time: "
                               << std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count();

  const auto& end_status = getServerSideStatus();

  if (verbose()) {
    const auto& stats = summarizeServerStats(start_status, end_status);
    reportServerSideStats(stats);
  }

  //take ownership of the result and extend its lifetime into the output objects
  std::shared_ptr<nic::InferResult> results_ptr(results);
  status = getResults(results_ptr);

  //status = getResults(std::make_shared<nvidia::inferenceserver::client::InferResult>(results));
  //status = getResults(std::make_shared<nic::InferResult>(results));

  finish(status);
}
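
//A typical synchronous caller of this client (a hedged sketch; method names are taken
//from this file and the class accessors, the producer module driving it is not shown):
//
//  client.start();                  // reset the retry counter
//  client.setBatchSize(nSamples);   // propagate the batch size to inputs and outputs
//  /* ...fill the input TritonData objects... */
//  client.evaluate();               // blocking Infer call, retried via finish()
//  /* ...read results from the output map, then client.reset() before the next event... */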

void TritonClient::finish(bool success) {
  //retries are only allowed if no exception was raised
  if (!success) {
    ++tries_;
    //if max retries has not been exceeded, call evaluate again
    if (tries_ < allowedTries_) {
      evaluate();
      //return so the retry does not fall through to the exception below
      return;
    }
    //throw an exception if the maximum number of tries has been exceeded
    throw cet::exception("TritonClient") << "call failed after max " << tries_ << " tries" << std::endl;
  }
}

void TritonClient::reportServerSideStats(const ServerSideStats& stats) const {
  std::ostringstream msg;

  // https://github.com/triton-inference-server/server/blob/v2.3.0/src/clients/c++/perf_client/inference_profiler.cc
  const uint64_t count = stats.success_count_;
  msg << "  Inference count: " << stats.inference_count_ << "\n";
  msg << "  Execution count: " << stats.execution_count_ << "\n";
  msg << "  Successful request count: " << count << "\n";

  if (count > 0) {
    auto get_avg_us = [count](uint64_t tval) {
      constexpr uint64_t us_to_ns = 1000;
      return tval / us_to_ns / count;
    };

    const uint64_t cumm_avg_us = get_avg_us(stats.cumm_time_ns_);
    const uint64_t queue_avg_us = get_avg_us(stats.queue_time_ns_);
    const uint64_t compute_input_avg_us = get_avg_us(stats.compute_input_time_ns_);
    const uint64_t compute_infer_avg_us = get_avg_us(stats.compute_infer_time_ns_);
    const uint64_t compute_output_avg_us = get_avg_us(stats.compute_output_time_ns_);
    const uint64_t compute_avg_us = compute_input_avg_us + compute_infer_avg_us + compute_output_avg_us;
    const uint64_t overhead =
        (cumm_avg_us > queue_avg_us + compute_avg_us) ? (cumm_avg_us - queue_avg_us - compute_avg_us) : 0;

    msg << "  Avg request latency: " << cumm_avg_us << " usec"
        << "\n"
        << "  (overhead " << overhead << " usec + "
        << "queue " << queue_avg_us << " usec + "
        << "compute input " << compute_input_avg_us << " usec + "
        << "compute infer " << compute_infer_avg_us << " usec + "
        << "compute output " << compute_output_avg_us << " usec)" << std::endl;
  }

  MF_LOG_DEBUG("TritonClient") << msg.str();
}

TritonClient::ServerSideStats TritonClient::summarizeServerStats(const inference::ModelStatistics& start_status,
                                                                 const inference::ModelStatistics& end_status) const {
  TritonClient::ServerSideStats server_stats;

  server_stats.inference_count_ = end_status.inference_count() - start_status.inference_count();
  server_stats.execution_count_ = end_status.execution_count() - start_status.execution_count();
  server_stats.success_count_ =
      end_status.inference_stats().success().count() - start_status.inference_stats().success().count();
  server_stats.cumm_time_ns_ =
      end_status.inference_stats().success().ns() - start_status.inference_stats().success().ns();
  server_stats.queue_time_ns_ =
      end_status.inference_stats().queue().ns() - start_status.inference_stats().queue().ns();
  server_stats.compute_input_time_ns_ =
      end_status.inference_stats().compute_input().ns() - start_status.inference_stats().compute_input().ns();
  server_stats.compute_infer_time_ns_ =
      end_status.inference_stats().compute_infer().ns() - start_status.inference_stats().compute_infer().ns();
  server_stats.compute_output_time_ns_ =
      end_status.inference_stats().compute_output().ns() - start_status.inference_stats().compute_output().ns();

  return server_stats;
}

inference::ModelStatistics TritonClient::getServerSideStatus() const {
  if (verbose_) {
    inference::ModelStatisticsResponse resp;
    bool success = triton_utils::warnIfError(
        client_->ModelInferenceStatistics(&resp, options_.model_name_, options_.model_version_),
        "getServerSideStatus(): unable to get model statistics");
    if (success)
      return *(resp.model_stats().begin());
  }
  //if not verbose (or the query failed), return empty statistics
  return inference::ModelStatistics{};
}

}