7 #include "grpc_client.h" 8 #include "grpc_service.pb.h" 20 namespace nic = ni::client;
TritonClient::TritonClient(const fhicl::ParameterSet& params)
  : allowedTries_(params.get<unsigned>("allowedTries", 0)),
    verbose_(params.get<bool>("verbose")),
    // ... (other member initializers elided in this excerpt)
{
  // Create the gRPC client for the inference server (the server URL is built
  // from parameters not shown in this excerpt); throws on failure.
  triton_utils::throwIfError(nic::InferenceServerGrpcClient::Create(&client_, url, false),
                             "TritonClient(): unable to create inference context");

  // Convert the configured timeout from seconds to microseconds.
  options_.client_timeout_ = params.get<unsigned>("timeout") * 1e6;
  // Query the model configuration from the server.
  inference::ModelConfigResponse modelConfigResponse;
  triton_utils::throwIfError(
    client_->ModelConfig(&modelConfigResponse, options_.model_name_, options_.model_version_),
    "TritonClient(): unable to get model config");
  inference::ModelConfig modelConfig(modelConfigResponse.config());

  // ... (max batch size determination from the model config elided); clamp to at least 1.
  maxBatchSize_ = std::max(1u, maxBatchSize_);

  // Query the model metadata (inputs and outputs) from the server.
  inference::ModelMetadataResponse modelMetadata;
  triton_utils::throwIfError(
    client_->ModelMetadata(&modelMetadata, options_.model_name_, options_.model_version_),
    "TritonClient(): unable to get model metadata");
  const auto& nicInputs = modelMetadata.inputs();
  const auto& nicOutputs = modelMetadata.outputs();

  // Sanity checks on the model I/O reported by the server; any accumulated
  // complaints are handled below (not shown in this excerpt).
  std::ostringstream msg;
  if (nicInputs.empty())
    msg << "Model on server appears malformed (zero inputs)\n";
  if (nicOutputs.empty())
    msg << "Model on server appears malformed (zero outputs)\n";
  // Book one entry per model input and collect a human-readable summary for logging.
  std::ostringstream io_msg;
  io_msg << "Model inputs: "
         << "\n";
  for (const auto& nicInput : nicInputs) {
    const auto& iname = nicInput.name();
    auto [curr_itr, success] = input_.try_emplace(iname, iname, nicInput, noBatch_);
    auto& curr_input = curr_itr->second;
    io_msg << "  " << iname << " (" << curr_input.dname() << ", " << curr_input.byteSize()
           << ")\n";  // (rest of this printout elided in this excerpt)
  }
  // Outputs requested in the configuration; an empty list means "keep all".
  const auto& v_outputs = params.get<std::vector<std::string>>("outputs");
  std::unordered_set<std::string> s_outputs(v_outputs.begin(), v_outputs.end());
  io_msg << "Model outputs: "
         << "\n";
  for (const auto& nicOutput : nicOutputs) {
    const auto& oname = nicOutput.name();
    // Skip model outputs that were not requested (when a non-empty list was given).
    if (!s_outputs.empty() and s_outputs.find(oname) == s_outputs.end())
      continue;
    auto [curr_itr, success] = output_.try_emplace(oname, oname, nicOutput, noBatch_);
    auto& curr_output = curr_itr->second;
    io_msg << "  " << oname << " (" << curr_output.dname() << ", " << curr_output.byteSize()
           << ")\n";  // (rest of this printout elided in this excerpt)
    if (!s_outputs.empty())
      s_outputs.erase(oname);
  }

  // Any requested output that was not found in the model metadata is treated as
  // an error (handling not shown in this excerpt).
  if (!s_outputs.empty()) {
    // ...
  }
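  // Example of the selection logic above (hypothetical names): if the model
  // reports outputs {"scores", "boxes", "masks"} and the configuration requests
  // outputs: ["scores", "masks"], only "scores" and "masks" are booked and both
  // are erased from s_outputs, leaving it empty. Requesting ["scores", "foo"]
  // instead would leave {"foo"} behind and trigger the error branch above.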
  // Log a summary of the model and its inputs/outputs.
  std::ostringstream model_msg;
  model_msg << "Model name: " << options_.model_name_ << "\n"
            << "Model version: " << options_.model_version_ << "\n";
  MF_LOG_INFO("TritonClient") << model_msg.str() << io_msg.str();
}

bool TritonClient::setBatchSize(unsigned bsize) {
  if (bsize > maxBatchSize_) {
    MF_LOG_WARNING("TritonClient") << "Requested batch size " << bsize
                                   << " exceeds server-specified max batch size " << maxBatchSize_;
    return false;
  }
  // ... (stored batch size updated here)
  // Propagate the new batch size to all input and output objects.
  for (auto& element : input_) {
    element.second.setBatchSize(bsize);
  }
  for (auto& element : output_) {
    element.second.setBatchSize(bsize);
  }
  return true;
}

void TritonClient::reset() {
  // Clear per-call state on all input and output objects.
  for (auto& element : input_) {
    element.second.reset();
  }
  for (auto& element : output_) {
    element.second.reset();
  }
}

bool TritonClient::getResults(std::shared_ptr<nic::InferResult> results) {
  for (auto& [oname, output] : output_) {
    // For outputs with variable dimensions, fetch the concrete shape from the result.
    if (output.variableDims()) {
      std::vector<int64_t> tmp_shape;
      if (!triton_utils::warnIfError(results->Shape(oname, &tmp_shape),
                                     "getResults(): unable to get output shape for " + oname))
        return false;
      output.setShape(tmp_shape, false);
    }
    // Hand the result to the output object, extending its lifetime as needed.
    output.setResult(results);
  }
  return true;
}
  // Run the inference call; the raw result pointer is owned below.
  nic::InferResult* results;
  triton_utils::warnIfError(client_->Infer(&results, options_, inputsTriton_, outputsTriton_),
                            "evaluate(): unable to run and/or get result");
  // ... (error handling and timing elided in this excerpt)
  // Wrap the raw pointer so its lifetime is tied to the result-processing path
  // (e.g. the getResults() call); no manual delete is needed.
  std::shared_ptr<nic::InferResult> results_ptr(results);

void TritonClient::reportServerSideStats(const ServerSideStats& stats) const {
  std::ostringstream msg;

  // ... (inference and execution count lines elided in this excerpt)
  const uint64_t count = stats.success_count_;
  msg << " Successful request count: " << count << "\n";

  // Guard against division by zero when no requests succeeded.
  if (count > 0) {
    // Convert a cumulative time in nanoseconds into an average per request in microseconds.
    auto get_avg_us = [count](uint64_t tval) {
      constexpr uint64_t us_to_ns = 1000;
      return tval / us_to_ns / count;
    };

    const uint64_t cumm_avg_us = get_avg_us(stats.cumm_time_ns_);
    const uint64_t queue_avg_us = get_avg_us(stats.queue_time_ns_);
    const uint64_t compute_input_avg_us = get_avg_us(stats.compute_input_time_ns_);
    const uint64_t compute_infer_avg_us = get_avg_us(stats.compute_infer_time_ns_);
    const uint64_t compute_output_avg_us = get_avg_us(stats.compute_output_time_ns_);
    const uint64_t compute_avg_us = compute_input_avg_us + compute_infer_avg_us + compute_output_avg_us;
    // Anything not accounted for by queueing or compute is reported as overhead;
    // clamp at zero in case the components add up to more than the total.
    const uint64_t overhead =
      (cumm_avg_us > queue_avg_us + compute_avg_us) ? (cumm_avg_us - queue_avg_us - compute_avg_us) : 0;
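    // Worked example with made-up numbers: if cumm_avg_us = 1500, queue_avg_us = 200,
    // and the compute input/infer/output averages are 100 + 1000 + 100 = 1200 usec,
    // then overhead = 1500 - 200 - 1200 = 100 usec. If rounding made the parts sum
    // to more than the total, the ternary above reports 0 instead of underflowing
    // the unsigned subtraction.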
    msg << " Avg request latency: " << cumm_avg_us << " usec"
        << " (overhead " << overhead << " usec + "
        << "queue " << queue_avg_us << " usec + "
        << "compute input " << compute_input_avg_us << " usec + "
        << "compute infer " << compute_infer_avg_us << " usec + "
        << "compute output " << compute_output_avg_us << " usec)" << std::endl;
  }

  // ... (the assembled message is logged here; remainder of the function elided)
}

TritonClient::ServerSideStats TritonClient::summarizeServerStats(
  const inference::ModelStatistics& start_status,
  const inference::ModelStatistics& end_status) const {
  // Per-call statistics are obtained by differencing the server's cumulative
  // counters taken before and after the request.
  ServerSideStats server_stats;

  server_stats.inference_count_ = end_status.inference_count() - start_status.inference_count();
  server_stats.execution_count_ = end_status.execution_count() - start_status.execution_count();
  server_stats.success_count_ =
    end_status.inference_stats().success().count() - start_status.inference_stats().success().count();
  server_stats.cumm_time_ns_ =
    end_status.inference_stats().success().ns() - start_status.inference_stats().success().ns();
  server_stats.queue_time_ns_ =
    end_status.inference_stats().queue().ns() - start_status.inference_stats().queue().ns();
  server_stats.compute_input_time_ns_ =
    end_status.inference_stats().compute_input().ns() - start_status.inference_stats().compute_input().ns();
  server_stats.compute_infer_time_ns_ =
    end_status.inference_stats().compute_infer().ns() - start_status.inference_stats().compute_infer().ns();
  server_stats.compute_output_time_ns_ =
    end_status.inference_stats().compute_output().ns() - start_status.inference_stats().compute_output().ns();

  return server_stats;
}

inference::ModelStatistics TritonClient::getServerSideStatus() const {
  if (verbose_) {
    inference::ModelStatisticsResponse resp;
    triton_utils::warnIfError(
      client_->ModelInferenceStatistics(&resp, options_.model_name_, options_.model_version_),
      "getServerSideStatus(): unable to get model statistics");
    return *(resp.model_stats().begin());
  }
  return inference::ModelStatistics{};
}
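
// Illustrative use of the statistics helpers above (hypothetical calling code,
// not part of this file): snapshot the server's cumulative counters around a
// request and report the difference.
//
//   const auto start_status = getServerSideStatus();
//   // ... run the inference request ...
//   const auto end_status = getServerSideStatus();
//   reportServerSideStats(summarizeServerStats(start_status, end_status));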