Graph.cxx
Go to the documentation of this file.
1 #include "WireCellPgraph/Graph.h"
2 #include "WireCellUtil/Type.h"
3 
4 #include <unordered_map>
5 #include <unordered_set>
6 
7 
9 using namespace WireCell::Pgraph;
10 
12  : l(Log::logger("pgraph"))
13 {
14 }
15 
16 void Graph::add_node(Node* node)
17 {
18  m_nodes.insert(node);
19 }
20 
21 bool Graph::connect(Node* tail, Node* head, size_t tpind, size_t hpind)
22 {
23  Port& tport = tail->output_ports()[tpind];
24  Port& hport = head->input_ports()[hpind];
25  if (tport.signature() != hport.signature()) {
26  l->critical ("port signature mismatch: \"{}\" != \"{}\"",
27  tport.signature (), hport.signature());
28  THROW(ValueError() << errmsg{"port signature mismatch"});
29  return false;
30  }
31 
32  m_edges.push_back(std::make_pair(tail,head));
33  Edge edge = std::make_shared<Queue>();
34 
35  tport.plug(edge);
36  hport.plug(edge);
37 
38  add_node(tail);
39  add_node(head);
40 
41  m_edges_forward[tail].push_back(head);
42  m_edges_backward[head].push_back(tail);
43 
44  SPDLOG_LOGGER_TRACE(l, "connect {}:({}:{}) -> {}({}:{})",
45  tail->ident(),
46  demangle(tport.signature ()),
47  tpind,
48  head->ident(),
49  demangle(hport.signature()),
50  hpind);
51 
52  return true;
53 }
54 
55 std::vector<Node*> Graph::sort_kahn() {
56 
57  std::unordered_map<Node*, int> nincoming;
58  for (auto th : m_edges) {
59 
60  nincoming[th.first] += 0; // make sure all nodes represented
61  nincoming[th.second] += 1;
62  }
63 
64  std::vector<Node*> ret;
65  std::unordered_set<Node*> seeds;
66 
67  for (auto it : nincoming) {
68  if (it.second == 0) {
69  seeds.insert(it.first);
70  }
71  }
72 
73  while (!seeds.empty()) {
74  Node* t = *seeds.begin();
75  seeds.erase(t);
76  ret.push_back(t);
77 
78  for (auto h : m_edges_forward[t]) {
79  nincoming[h] -= 1;
80  if (nincoming[h] == 0) {
81  seeds.insert(h);
82  }
83  }
84  }
85  return ret;
86 }
87 
89 {
90  int count = 0;
91  for (auto parent : m_edges_backward[node]) {
92  bool ok = call_node(parent);
93  if (ok) {
94  ++count;
95  continue;
96  }
97  count += execute_upstream(parent);
98  }
99  bool ok = call_node(node);
100  if (ok) { ++count; }
101  return count;
102 }
103 
104 // this bool indicates exception, and is probably ignored
106 {
107  auto nodes = sort_kahn();
108  l->debug("executing with {} nodes", nodes.size());
109 
110  while (true) {
111 
112  int count = 0;
113  bool did_something = false;
114 
115  for (auto nit = nodes.rbegin(); nit != nodes.rend(); ++nit, ++count) {
116  Node* node = *nit;
117 
118  bool ok = call_node(node);
119  if (ok) {
120  SPDLOG_LOGGER_TRACE(l, "ran node {}: {}", count, node->ident());
121  did_something = true;
122  break; // start again from bottom of graph
123  }
124 
125  }
126 
127  if (!did_something) {
128  return true; // it's okay to do nothing.
129  }
130  }
131  return true; // shouldn't reach
132 }
133 
135 {
136  if (!node) {
137  l->error("graph call: got nullptr node");
138  return false;
139  }
140  bool ok = (*node)();
141  // this can be very noisy but useful to uncomment to understand
142  // the graph execution order.
143  if (ok) {
144  SPDLOG_LOGGER_TRACE(l, "graph call [{}] called: {}", ok, node->ident());
145  }
146  return ok;
147 }
148 
150 {
151  for (auto n : m_nodes) {
152  if (!n->connected()) {
153  return false;
154  }
155  }
156  return true;
157 }
158 
std::unordered_set< Node * > m_nodes
Definition: Graph.h:67
PortList & input_ports()
Definition: Node.h:29
Edge plug(Edge edge)
Definition: Port.cxx:23
boost::error_info< struct tag_errmsg, std::string > errmsg
Definition: Exceptions.h:54
virtual std::string ident()=0
std::shared_ptr< Queue > Edge
Definition: Port.h:27
const std::string & signature()
Definition: Port.cxx:73
void add_node(Node *node)
Definition: Graph.cxx:16
bool call_node(Node *node)
Definition: Graph.cxx:134
std::pair< int, int > edge(const realseq_t &wave)
Definition: Waveform.cxx:121
static QStrList * l
Definition: config.cpp:1044
std::vector< Node * > sort_kahn()
Definition: Graph.cxx:55
std::vector< std::pair< Node *, Node * > > m_edges
Definition: Graph.h:66
bool connect(Node *tail, Node *head, size_t tpind=0, size_t hpind=0)
Definition: Graph.cxx:21
#define nodes
std::string demangle(const std::string &name)
Definition: Type.cxx:6
#define THROW(e)
Definition: Exceptions.h:25
std::unordered_map< Node *, std::vector< Node * > > m_edges_backward
Definition: Graph.h:68
Log::logptr_t l
Definition: Graph.h:70
logptr_t logger(std::string name)
Definition: Logging.cxx:71
Thrown when a wrong value has been encountered.
Definition: Exceptions.h:37
std::vector< TrajPoint > seeds
Definition: DataStructs.cxx:13
#define SPDLOG_LOGGER_TRACE(logger,...)
Definition: spdlog.h:319
int execute_upstream(Node *node)
Definition: Graph.cxx:88
std::unordered_map< Node *, std::vector< Node * > > m_edges_forward
Definition: Graph.h:68
std::size_t n
Definition: format.h:3399
def parent(G, child, parent_type)
Definition: graph.py:67
h
training ###############################
Definition: train_cnn.py:186
PortList & output_ports()
Definition: Node.h:32