PointIdAlgSonicTriton_tool.cc
Go to the documentation of this file.
1 ////////////////////////////////////////////////////////////////////////////////////////////////////
2 // Class: PointIdAlgSonicTriton_tool
3 // Authors: M.Wang
4 ////////////////////////////////////////////////////////////////////////////////////////////////////
5 
9 
10 namespace PointIdAlgTools {
11 
// Public interface: construction from a FHiCL configuration table and the two
// Run() overrides (one 2D patch, or a batch of 2D patches).
13  public:
14  explicit PointIdAlgSonicTriton(fhicl::Table<Config> const& table);
15 
// Single-patch inference: takes one 2D array of floats and returns one flat
// vector of network outputs.
16  std::vector<float> Run(std::vector<std::vector<float>> const& inp2d) const override;
// Batch inference: takes a list of 2D patches and returns one output vector per
// processed patch. `samples` defaults to -1, which the implementation treats as
// "process every patch in inps".
17  std::vector<std::vector<float>> Run(std::vector<std::vector<std::vector<float>>> const& inps,
18  int samples = -1) const override;
19 
20  private:
// Forwarded to the Triton client under the "timeout" key of its parameter set.
// NOTE(review): units are not visible in this file — confirm against TritonClient.
25  unsigned fTritonTimeout;
27 
// Inference client; created in the constructor and used by both Run() overloads.
28  std::unique_ptr<lartriton::TritonClient> triton_client;
29  };
30 
31  // ------------------------------------------------------
33  : img::DataProviderAlg(table())
34  {
// Constructor body: copies configuration values out of `table`, assembles a
// fhicl::ParameterSet describing the Triton connection, creates the Triton
// client from it, logs the connection settings and finally calls resizePatch().
35  // ... Get common config vars
36  fNNetOutputs = table().NNetOutputs();
37  fPatchSizeW = table().PatchSizeW();
38  fPatchSizeD = table().PatchSizeD();
// NOTE(review): 99999 presumably acts as a "nothing cached yet" sentinel for the
// current wire / drift position — confirm against DataProviderAlg.
39  fCurrentWireIdx = 99999;
40  fCurrentScaledDrift = 99999;
41 
42  // ... Get "optional" config vars specific to tRTis interface
43  fTritonModelName = table().TritonModelName();
44  fTritonURL = table().TritonURL();
45  fTritonVerbose = table().TritonVerbose();
46  fTritonModelVersion = table().TritonModelVersion();
47  fTritonAllowedTries = table().TritonAllowedTries();
48 
49  // ... Create parameter set for Triton inference client
50  fhicl::ParameterSet TritonPset;
51  TritonPset.put("serverURL",fTritonURL);
52  TritonPset.put("verbose",fTritonVerbose);
53  TritonPset.put("modelName",fTritonModelName);
54  TritonPset.put("modelVersion",fTritonModelVersion);
// NOTE(review): fTritonTimeout is read here, but the visible constructor code
// never assigns it (all other fTriton* members are set above). Confirm it is
// initialised from the configuration; otherwise an indeterminate value is sent.
55  TritonPset.put("timeout",fTritonTimeout);
56  TritonPset.put("allowedTries",fTritonAllowedTries);
// "outputs" is given the literal string "[]".
// NOTE(review): presumably this means "no explicit output selection" — confirm
// how TritonClient interprets it.
57  TritonPset.put("outputs","[]");
58 
59  // ... Create the Triton inference client
60  triton_client = std::make_unique<lartriton::TritonClient>(TritonPset);
61 
// Record the connection settings in the message log.
62  mf::LogInfo("PointIdAlgSonicTriton") << "url: " << fTritonURL;
63  mf::LogInfo("PointIdAlgSonicTriton") << "model name: " << fTritonModelName;
64  mf::LogInfo("PointIdAlgSonicTriton") << "model version: " << fTritonModelVersion;
65  mf::LogInfo("PointIdAlgSonicTriton") << "verbose: " << fTritonVerbose;
66 
67  mf::LogInfo("PointIdAlgSonicTriton") << "tensorRT inference context created.";
68 
69  resizePatch();
70  }
71 
72  // ------------------------------------------------------
73  std::vector<float>
74  PointIdAlgSonicTriton::Run(std::vector<std::vector<float>> const& inp2d) const
75  {
76  size_t nrows = inp2d.size();
77 
78  triton_client->setBatchSize(1); // set batch size
79 
80  // ~~~~ Initialize the inputs
81  auto& triton_input = triton_client->input().begin()->second;
82 
83  auto data1 = std::make_shared<lartriton::TritonInput<float>>();
84  data1->reserve(1);
85 
86  // ~~~~ Prepare image for sending to server
87  auto& img = data1->emplace_back();
88  // ..first flatten the 2d array into contiguous 1d block
89  for (size_t ir = 0; ir < nrows; ++ir) {
90  img.insert(img.end(), inp2d[ir].begin(), inp2d[ir].end());
91  }
92 
93  triton_input.toServer(data1); // convert to server format
94 
95  // ~~~~ Send inference request
96  triton_client->dispatch();
97 
98  // ~~~~ Retrieve inference results
99  const auto& triton_output0 = triton_client->output().at("em_trk_none_netout/Softmax");
100  const auto& prob0 = triton_output0.fromServer<float>();
101  auto ncat0 = triton_output0.sizeDims();
102 
103  const auto& triton_output1 = triton_client->output().at("michel_netout/Sigmoid");
104  const auto& prob1 = triton_output1.fromServer<float>();
105  auto ncat1 = triton_output1.sizeDims();
106 
107  std::vector<float> out;
108  out.reserve(ncat0+ncat1);
109  out.insert(out.end(), prob0[0].begin(), prob0[0].end());
110  out.insert(out.end(), prob1[0].begin(), prob1[0].end());
111 
112  triton_client->reset();
113 
114  return out;
115  }
116 
117  // ------------------------------------------------------
118  std::vector<std::vector<float>>
119  PointIdAlgSonicTriton::Run(std::vector<std::vector<std::vector<float>>> const& inps, int samples) const
120  {
121  if ((samples == 0) || inps.empty() || inps.front().empty() || inps.front().front().empty()) {
122  return std::vector<std::vector<float>>();
123  }
124 
125  if ((samples == -1) || (samples > (long long int)inps.size())) { samples = inps.size(); }
126 
127  size_t usamples = samples;
128  size_t nrows = inps.front().size();
129 
130  triton_client->setBatchSize(usamples); // set batch size
131 
132  // ~~~~ Initialize the inputs
133  auto& triton_input = triton_client->input().begin()->second;
134 
135  auto data1 = std::make_shared<lartriton::TritonInput<float>>();
136  data1->reserve(usamples);
137 
138  // ~~~~ For each sample, prepare images for sending to server
139  for (size_t idx = 0; idx < usamples; ++idx) {
140  auto& img = data1->emplace_back();
141  // ..first flatten the 2d array into contiguous 1d block
142  for (size_t ir = 0; ir < nrows; ++ir) {
143  img.insert(img.end(), inps[idx][ir].begin(), inps[idx][ir].end());
144  }
145  }
146  triton_input.toServer(data1); // convert to server format
147 
148  // ~~~~ Send inference request
149  triton_client->dispatch();
150 
151  // ~~~~ Retrieve inference results
152  const auto& triton_output0 = triton_client->output().at("em_trk_none_netout/Softmax");
153  const auto& prob0 = triton_output0.fromServer<float>();
154  auto ncat0 = triton_output0.sizeDims();
155 
156  const auto& triton_output1 = triton_client->output().at("michel_netout/Sigmoid");
157  const auto& prob1 = triton_output1.fromServer<float>();
158  auto ncat1 = triton_output1.sizeDims();
159 
160  std::vector<std::vector<float>> out;
161  out.reserve(usamples);
162  for(unsigned i = 0; i < usamples; i++) {
163  out.emplace_back();
164  auto& img = out.back();
165  img.reserve(ncat0+ncat1);
166  img.insert(img.end(), prob0[i].begin(), prob0[i].end());
167  img.insert(img.end(), prob1[i].begin(), prob1[i].end());
168  }
169 
170  triton_client->reset();
171 
172  return out;
173  }
174 
175 }
#define DEFINE_ART_CLASS_TOOL(tool)
Definition: ToolMacros.h:42
std::string string
Definition: nybbler.cc:12
MaybeLogger_< ELseverityLevel::ELsev_info, false > LogInfo
struct vector vector
DataProviderAlg(const fhicl::ParameterSet &pset)
PointIdAlgSonicTriton(fhicl::Table< Config > const &table)
std::vector< float > Run(std::vector< std::vector< float >> const &inp2d) const override
std::vector< std::string > fNNetOutputs
Definition: IPointIdAlg.h:152
std::unique_ptr< lartriton::TritonClient > triton_client
unsigned nrows(sqlite3 *db, std::string const &tablename)
Definition: helpers.cc:82
void put(std::string const &key)