#include <RegCNN_TF_Graph.h>
Definition at line 28 of file RegCNN_TF_Graph.h.
tf::RegCNNGraph::~RegCNNGraph ( )
Definition at line 110 of file RegCNN_TF_Graph.cc.
114 std::cout <<
"RegCNNGraph::dtor: " <<
"Close failed." <<
std::endl;
tensorflow::Session * fSession
QTextStream & endl(QTextStream &s)
tf::RegCNNGraph::RegCNNGraph ( const char * graph_file_name, const unsigned int & ninputs, const std::vector< std::string > & outputs, bool & success )
[private]
Non-throwing constructor.
Definition at line 20 of file RegCNN_TF_Graph.cc.
24 auto status = tensorflow::NewSession(tensorflow::SessionOptions(), &
fSession);
31 tensorflow::GraphDef graph_def;
39 size_t ng = graph_def.node().size();
55 for (
unsigned int ii = 0; ii < ninputs; ++ii){
64 if (outputs.empty()) {
fOutputNames.push_back(graph_def.node()[ng - 1].name()); }
68 for (
size_t n = 0;
n < ng; ++
n)
70 name = graph_def.node()[
n].name();
71 auto pos = name.find(
"/");
72 if (
pos != std::string::npos) { basename = name.substr(0,
pos); }
76 for (
const auto &
s : outputs)
78 if (name.find(
s) != std::string::npos) { found =
true;
break; }
82 if (!last.empty() && (basename !=
current))
94 std::cout <<
"Output nodes not found in the graph." <<
std::endl;
107 std::cout<<
"ok, graph loaded from the file"<<
std::endl;
std::vector< std::string > fInputNames
std::vector< std::string > fOutputNames
tensorflow::Session * fSession
QTextStream & endl(QTextStream &s)
static std::unique_ptr<RegCNNGraph> tf::RegCNNGraph::create ( const char * graph_file_name, const unsigned int & ninputs, const std::vector< std::string > & outputs = {} )
[inline, static]
Definition at line 31 of file RegCNN_TF_Graph.h.
34 std::unique_ptr<RegCNNGraph> ptr(
new RegCNNGraph(graph_file_name, ninputs, outputs, success));
35 if (success) {
return ptr; }
36 else {
return nullptr; }
RegCNNGraph(const char *graph_file_name, const unsigned int &ninputs, const std::vector< std::string > &outputs, bool &success)
Non-throwing constructor.
Definition at line 120 of file RegCNN_TF_Graph.cc.
123 if (x.empty() || x.front().empty()) {
return std::vector<float>(); }
125 long long int rows = x.size(), cols = x.front().size();
127 tensorflow::Tensor _x(tensorflow::DT_FLOAT, tensorflow::TensorShape({ 1, rows, cols, 1 }));
128 auto input_map = _x.tensor<
float, 4>();
130 for (
long long int r = 0;
r < rows; ++
r) {
131 const auto &
row = x[
r];
132 for (
long long int c = 0;
c < cols; ++
c) {
133 input_map(0,
r,
c, 0) =
row[
c];
139 else {
return std::vector<float>(); }
std::vector< float > run(const std::vector< std::vector< float > > &x)
Definition at line 245 of file RegCNN_TF_Graph.cc.
249 if ((samples == 0) || x.empty() || x.front().empty() || x.front().front().empty() || x.front().front().front().empty())
252 if ((samples == -1) || (samples > (
long long int)x.size())) { samples = x.size(); }
255 rows = x.front().size(),
256 cols = x.front().front().size(),
257 depth = x.front().front().front().size();
260 tensorflow::Tensor _x(tensorflow::DT_FLOAT, tensorflow::TensorShape({ samples, rows, cols, depth }));
261 auto input_map = _x.tensor<
float, 4>();
262 for (
long long int s = 0;
s < samples; ++
s) {
263 const auto & sample = x[
s];
264 for (
long long int r = 0;
r < rows; ++
r) {
265 const auto &
row = sample[
r];
266 for (
long long int c = 0;
c < cols; ++
c) {
267 const auto & col =
row[
c];
268 for (
long long int d = 0;
d < depth; ++
d) {
269 input_map(
s,
r,
c,
d) = col[
d];
std::vector< float > run(const std::vector< std::vector< float > > &x)
Definition at line 143 of file RegCNN_TF_Graph.cc.
148 if (ninputs == 1)
return run(x);
150 if ((samples == 0) || x.empty() || x.front().empty() || x.front().front().empty() || x.front().front().front().empty())
153 if ((samples == -1) || (samples > (
long long int)x.size())) { samples = x.size(); }
156 rows = x.front().size(),
157 cols = x.front().front().size(),
158 depth = x.front().front().front().size();
163 std::vector< tensorflow::Tensor > _x;
164 for (
unsigned int ii = 0; ii < ninputs; ++ii){
165 tensorflow::Tensor _xtemp(tensorflow::DT_FLOAT, tensorflow::TensorShape({ samples, rows, cols, 1 }));
166 _x.push_back(_xtemp);
171 for (
long long int s = 0;
s < samples; ++
s) {
172 const auto & sample = x[
s];
173 for (
long long int r = 0;
r < rows; ++
r) {
174 const auto &
row = sample[
r];
175 for (
long long int c = 0;
c < cols; ++
c) {
176 const auto & col =
row[
c];
177 for (
long long int d = 0;
d < depth; ++
d) {
179 _x[
d].tensor<
float, 4>()(
s,
r,
c, 0) = col[
d];
std::vector< float > run(const std::vector< std::vector< float > > &x)
Definition at line 189 of file RegCNN_TF_Graph.cc.
196 if (ninputs == 1)
return run(x);
198 if ((samples == 0) || x.empty() || x.front().empty() || x.front().front().empty() || x.front().front().front().empty())
201 if ((samples == -1) || (samples > (
long long int)x.size())) { samples = x.size(); }
204 rows = x.front().size(),
205 cols = x.front().front().size(),
206 depth = x.front().front().front().size();
210 std::vector< tensorflow::Tensor > _x;
212 for (
unsigned int ii = 0; ii < 3; ++ii){
213 tensorflow::Tensor _xtemp(tensorflow::DT_FLOAT, tensorflow::TensorShape({ samples, rows, cols, 1 }));
214 _x.push_back(_xtemp);
217 tensorflow::Tensor _xtemp(tensorflow::DT_FLOAT, tensorflow::TensorShape({ samples, (
unsigned int)cm.size() }));
218 _x.push_back(_xtemp);
222 for (
long long int s = 0;
s < samples; ++
s) {
223 const auto & sample = x[
s];
224 for (
long long int r = 0;
r < rows; ++
r) {
225 const auto &
row = sample[
r];
226 for (
long long int c = 0;
c < cols; ++
c) {
227 const auto & col =
row[
c];
228 for (
long long int d = 0;
d < depth; ++
d) {
230 _x[
d].tensor<
float, 4>()(
s,
r,
c, 0) = col[
d];
234 for (
unsigned int icm = 0; icm < (
unsigned int)cm.size(); ++icm){
235 _x[depth].tensor<
float, 2>()(
s, icm) = cm[icm];
std::vector< float > run(const std::vector< std::vector< float > > &x)
Definition at line 279 of file RegCNN_TF_Graph.cc.
283 std::vector< std::pair<std::string, tensorflow::Tensor> >
inputs = {
289 std::vector<tensorflow::Tensor>
outputs;
296 size_t samples = 0, nouts = 0;
297 for (
size_t o = 0; o < outputs.size(); ++o)
299 if (o == 0) { samples = outputs[o].dim_size(0); }
300 else if ((
int)samples != outputs[o].dim_size(0))
302 throw std::string(
"TF outputs size inconsistent.");
304 nouts += outputs[o].dim_size(1);
308 std::vector< std::vector< float > >
result;
309 result.resize(samples, std::vector< float >(nouts));
312 for (
size_t o = 0; o < outputs.size(); ++o)
314 auto output_map = outputs[o].tensor<
float, 2>();
316 size_t n = outputs[o].dim_size(1);
317 for (
size_t s = 0;
s < samples; ++
s) {
318 std::vector< float > & vs = result[
s];
319 for (
size_t i = 0; i <
n; ++i) {
320 vs[idx0 + i] = output_map(
s, i);
330 return std::vector< std::vector< float > >();
std::vector< std::string > fInputNames
std::vector< std::string > fOutputNames
tensorflow::Session * fSession
QTextStream & endl(QTextStream &s)
Definition at line 334 of file RegCNN_TF_Graph.cc.
339 std::vector< std::pair<std::string, tensorflow::Tensor> >
inputs;
340 for (
size_t ii = 0; ii < x.size(); ++ii){
341 inputs.push_back(std::make_pair(
fInputNames[ii], x[ii]));
346 std::vector<tensorflow::Tensor>
outputs;
353 size_t samples = 0, nouts = 0;
354 for (
size_t o = 0; o < outputs.size(); ++o)
356 if (o == 0) { samples = outputs[o].dim_size(0); }
357 else if ((
int)samples != outputs[o].dim_size(0))
359 throw std::string(
"TF outputs size inconsistent.");
361 nouts += outputs[o].dim_size(1);
365 std::vector< std::vector< float > >
result;
366 result.resize(samples, std::vector< float >(nouts));
369 for (
size_t o = 0; o < outputs.size(); ++o)
371 auto output_map = outputs[o].tensor<
float, 2>();
373 size_t n = outputs[o].dim_size(1);
374 for (
size_t s = 0;
s < samples; ++
s) {
375 std::vector< float > & vs = result[
s];
376 for (
size_t i = 0; i <
n; ++i) {
377 vs[idx0 + i] = output_map(
s, i);
387 return std::vector< std::vector< float > >();
std::vector< std::string > fInputNames
std::vector< std::string > fOutputNames
tensorflow::Session * fSession
QTextStream & endl(QTextStream &s)
tensorflow::Session* tf::RegCNNGraph::fSession
[private]
The documentation for this class was generated from the following files: RegCNN_TF_Graph.h, RegCNN_TF_Graph.cc