TritonData.cc
Go to the documentation of this file.
4 
5 #include "model_config.pb.h"
6 
7 #include <cstring>
8 #include <sstream>
9 
10 namespace ni = nvidia::inferenceserver;
11 namespace nic = ni::client;
12 
13 namespace nvidia {
14  namespace inferenceserver {
15  //in libgrpcclient.so, but corresponding header src/core/model_config.h not available
16  size_t GetDataTypeByteSize(const inference::DataType dtype);
17  inference::DataType ProtocolStringToDataType(const std::string& dtype);
18  } // namespace inferenceserver
19 } // namespace nvidia
20 
21 namespace lartriton {
22 
23 //dims: kept constant, represents config.pbtxt parameters of model (converted from google::protobuf::RepeatedField to vector)
24 //fullShape: if batching is enabled, first entry is batch size; values can be modified
25 //shape: view into fullShape, excluding batch size entry
//Constructor: describes one model input or output from its metadata (model_info) and creates
//the underlying Triton client object (InferInput or InferRequestedOutput, chosen by IO).
//noBatch: when false, fullShape_[0] is the batch-size entry and shape_ starts after it.
//NOTE(review): several members below are initialized from earlier ones (fullShape_ from dims_,
//shape_ from fullShape_, productDims_ from variableDims_, dtype_ from dname_, byteSize_ from dtype_);
//this relies on their declaration order in TritonData.h matching — confirm when editing the header.
26 template <typename IO>
28  : name_(name),
29  dims_(model_info.shape().begin(), model_info.shape().end()),
30  noBatch_(noBatch),
31  batchSize_(0),
32  fullShape_(dims_),
33  shape_(fullShape_.begin() + (noBatch_ ? 0 : 1), fullShape_.end()),
//anyNeg/dimProduct are defined elsewhere; presumably "any entry negative" (variable dims) and
//"product of entries" — productDims_ is set to -1 whenever a variable dimension is present
34  variableDims_(anyNeg(shape_)),
35  productDims_(variableDims_ ? -1 : dimProduct(shape_)),
//element type and per-element byte size, via the libgrpcclient helpers declared at the top of this file
36  dname_(model_info.datatype()),
37  dtype_(ni::ProtocolStringToDataType(dname_)),
38  byteSize_(ni::GetDataTypeByteSize(dtype_)) {
39  //create input or output object
//NOTE(review): iotmp is uninitialized; it is only valid if createObject() either sets it or throws
40  IO* iotmp;
41  createObject(&iotmp);
//data_ takes ownership of the object created by the Triton client factory
42  data_.reset(iotmp);
43 }
44 
45 template <>
46 void TritonInputData::createObject(nic::InferInput** ioptr) const {
47  nic::InferInput::Create(ioptr, name_, fullShape_, dname_);
48 }
49 
50 template <>
51 void TritonOutputData::createObject(nic::InferRequestedOutput** ioptr) const {
52  nic::InferRequestedOutput::Create(ioptr, name_);
53 }
54 
55 //setters
56 template <typename IO>
57 bool TritonData<IO>::setShape(const TritonData<IO>::ShapeType& newShape, bool canThrow) {
58  bool result = true;
59  for (unsigned i = 0; i < newShape.size(); ++i) {
60  result &= setShape(i, newShape[i], canThrow);
61  }
62  return result;
63 }
64 
65 template <typename IO>
66 bool TritonData<IO>::setShape(unsigned loc, int64_t val, bool canThrow) {
67  std::stringstream msg;
68  unsigned full_loc = loc + (noBatch_ ? 0 : 1);
69 
70  //check boundary
71  if (full_loc >= fullShape_.size()) {
72  msg << name_ << " setShape(): dimension " << full_loc << " out of bounds (" << fullShape_.size() << ")";
73  if (canThrow)
74  throw cet::exception("TritonDataError") << msg.str();
75  else {
76  MF_LOG_WARNING("TritonDataWarning") << msg.str();
77  return false;
78  }
79  }
80 
81  if (val != fullShape_[full_loc]) {
82  if (dims_[full_loc] == -1) {
83  fullShape_[full_loc] = val;
84  return true;
85  } else {
86  msg << name_ << " setShape(): attempt to change value of non-variable shape dimension " << loc;
87  if (canThrow)
88  throw cet::exception("TritonDataError") << msg.str();
89  else {
90  MF_LOG_WARNING("TritonDataError") << msg.str();
91  return false;
92  }
93  }
94  }
95 
96  return true;
97 }
98 
99 template <typename IO>
100 void TritonData<IO>::setBatchSize(unsigned bsize) {
101  batchSize_ = bsize;
102  if (!noBatch_)
103  fullShape_[0] = batchSize_;
104 }
105 
106 //io accessors
107 template <>
108 template <typename DT>
109 void TritonInputData::toServer(std::shared_ptr<TritonInput<DT>> ptr) {
110  const auto& data_in = *ptr;
111 
112  //check batch size
113  if (data_in.size() != batchSize_) {
114  throw cet::exception("TritonDataError") << name_ << " input(): input vector has size " << data_in.size()
115  << " but specified batch size is " << batchSize_;
116  }
117 
118  //shape must be specified for variable dims or if batch size changes
119  data_->SetShape(fullShape_);
120 
121  if (byteSize_ != sizeof(DT))
122  throw cet::exception("TritonDataError") << name_ << " input(): inconsistent byte size " << sizeof(DT)
123  << " (should be " << byteSize_ << " for " << dname_ << ")";
124 
125  int64_t nInput = sizeShape();
126  for (unsigned i0 = 0; i0 < batchSize_; ++i0) {
127  const DT* arr = data_in[i0].data();
128  triton_utils::throwIfError(data_->AppendRaw(reinterpret_cast<const uint8_t*>(arr), nInput * byteSize_),
129  name_ + " input(): unable to set data for batch entry " + std::to_string(i0));
130  }
131 
132  //keep input data in scope
133  holder_ = std::move(ptr);
134 }
135 
//Returns the server result for this output as one non-owning span per batch entry.
//Each span covers sizeShape() elements of type DT inside the raw buffer from result_->RawData().
//NOTE(review): the spans presumably stay valid only while result_ is held (reset() releases it) — confirm.
//Throws cet::exception("TritonDataError") if there is no result, if sizeof(DT) does not match the model
//data type, or if the returned byte count differs from sizeShape() * byteSize_ * batchSize_.
136 template <>
137 template <typename DT>
139  if (!result_) {
140  throw cet::exception("TritonDataError") << name_ << " output(): missing result";
141  }
142
143  if (byteSize_ != sizeof(DT)) {
144  throw cet::exception("TritonDataError") << name_ << " output(): inconsistent byte size " << sizeof(DT)
145  << " (should be " << byteSize_ << " for " << dname_ << ")";
146  }
147
//NOTE(review): sizeShape() returns int64_t; a negative value (unset variable dimension) would wrap here — TODO confirm it cannot occur
148  uint64_t nOutput = sizeShape();
149  TritonOutput<DT> dataOut;
150  const uint8_t* r0;
151  size_t contentByteSize;
152  size_t expectedContentByteSize = nOutput * byteSize_ * batchSize_;
//NOTE(review): unlike the other messages in this file, this one is not prefixed with name_
153  triton_utils::throwIfError(result_->RawData(name_, &r0, &contentByteSize), "output(): unable to get raw");
154  if (contentByteSize != expectedContentByteSize) {
155  throw cet::exception("TritonDataError") << name_ << " output(): unexpected content byte size " << contentByteSize
156  << " (expected " << expectedContentByteSize << ")";
157  }
158
//reinterpret the raw bytes as DT elements and slice them into consecutive per-entry spans of nOutput elements
159  const DT* r1 = reinterpret_cast<const DT*>(r0);
160  dataOut.reserve(batchSize_);
161  for (unsigned i0 = 0; i0 < batchSize_; ++i0) {
162  auto offset = i0 * nOutput;
163  dataOut.emplace_back(r1 + offset, r1 + offset + nOutput);
164  }
165
166  return dataOut;
167 }
168 
//Clears the data previously appended to the Triton input object and releases the input
//container that toServer() kept alive in holder_.
169 template <>
//NOTE(review): if data_->Reset() reports errors through a return value, it is discarded here — confirm intended
171  data_->Reset();
172  holder_.reset();
173 }
174 
//Releases the stored inference result. Any spans previously returned by fromServer() point into
//this result and presumably must not be used afterwards.
175 template <>
177  result_.reset();
178 }
179 
180 //explicit template instantiation definitions
//(these emit, in this translation unit, the class members and toServer() data types supported by the library)
181 template class TritonData<nic::InferInput>;
183
//toServer() is a member template, so each supported element type must be instantiated explicitly
184 template void TritonInputData::toServer(std::shared_ptr<TritonInput<float>> data_in);
185 template void TritonInputData::toServer(std::shared_ptr<TritonInput<int64_t>> data_in);
186 
188 
189 }
static QCString name
Definition: declinfo.cpp:673
end
while True: pbar.update(maxval-len(onlies[E][S])) #print iS, "/", len(onlies[E][S]) found = False for...
std::string name_
Definition: TritonData.h:84
static QCString result
void setBatchSize(unsigned bsize)
Definition: TritonData.cc:100
bool setShape(const ShapeType &newShape)
Definition: TritonData.h:42
void msg(const char *fmt,...)
Definition: message.cpp:107
const ShapeType dims_
Definition: TritonData.h:86
std::string string
Definition: nybbler.cc:12
ShapeType fullShape_
Definition: TritonData.h:89
std::vector< int64_t > ShapeType
Definition: TritonData.h:35
void throwIfError(const Error &err, std::string_view msg)
Definition: triton_utils.cc:22
void toServer(std::shared_ptr< TritonInput< DT >> ptr)
Definition: TritonData.cc:109
std::vector< std::vector< DT >> TritonInput
Definition: TritonData.h:25
def move(depos, offset)
Definition: depos.py:107
int64_t sizeShape() const
Definition: TritonData.h:61
void createObject(IO **ioptr) const
std::shared_ptr< Result > result_
Definition: TritonData.h:97
std::vector< triton_span::Span< const DT * >> TritonOutput
Definition: TritonData.h:27
decltype(auto) constexpr begin(T &&obj)
ADL-aware version of std::begin.
Definition: StdUtils.h:72
std::shared_ptr< IO > data_
Definition: TritonData.h:85
inference::ModelMetadataResponse_TensorMetadata TensorMetadata
Definition: TritonData.h:34
#define MF_LOG_WARNING(category)
std::string dname_
Definition: TritonData.h:93
TritonOutput< DT > fromServer() const
Definition: TritonData.cc:138
inference::DataType ProtocolStringToDataType(const std::string &dtype)
std::string to_string(ModuleType const mt)
Definition: ModuleType.h:34
size_t GetDataTypeByteSize(const inference::DataType dtype)
cet::coded_exception< error, detail::translate > exception
Definition: exception.h:33