EmTrackClusterId3out_module.cc
Go to the documentation of this file.
1 /////////////////////////////////////////////////////////////////////////////////
2 // Class: EmTrackClusterId
3 // Module Type: producer
4 // File: EmTrackClusterId3out_module.cc
5 // Authors: dorota.stefan@cern.ch pplonski86@gmail.com robert.sulej@cern.ch
6 //
7 // Module applies CNN to 2D image made of deconvoluted wire waveforms in order
8 // to distinguish EM-like activity from track-like objects. New clusters of
9 // hits are produced to include also unclustered hits and tag everything in
10 // a common way.
11 // NOTE: This module uses 3-output CNN models, see EmTrackMichelClusterId for
12 // usage of 4-output models and EmTrackClusterId2out_module.cc for 2-output
13 // models.
14 //
15 /////////////////////////////////////////////////////////////////////////////////
16 
24 #include "fhiclcpp/types/Atom.h"
26 #include "fhiclcpp/types/Table.h"
28 
36 
38 
39 #include <memory>
40 
41 namespace nnet {
42 
// EmTrackClusterId: art producer that tags hits, and optionally clusters and tracks,
// with the outputs of a 3-output CNN (the class header line is not visible in this view).
44  public:
45  // these types to be replaced with use of feature proposed in redmine #12602
// Nested key maps, presumably used for the hit/cluster maps in produce():
// cryostat -> TPC -> plane -> list of art::Ptr keys (e.g. h.key() / c.key()).
46  typedef std::unordered_map<unsigned int, std::vector<size_t>> view_keymap;
47  typedef std::unordered_map<unsigned int, view_keymap> tpc_view_keymap;
48  typedef std::unordered_map<unsigned int, tpc_view_keymap> cryo_tpc_view_keymap;
49 
// FHiCL configuration of the module; some member declaration lines are not
// visible in this view, only their Name(...)/Comment(...) fragments are.
50  struct Config {
51  using Name = fhicl::Name;
53 
56  Comment("number of samples processed in one batch")};
57 
59  Name("WireLabel"),
60  Comment("tag of deconvoluted ADC on wires (recob::Wire)")};
61 
63  Comment("tag of hits to be EM/track tagged")};
64 
66  Name("ClusterModuleLabel"),
67  Comment("tag of clusters to be used as a source of EM/track tagged new clusters (incl. "
68  "single-hit clusters ) using accumulated results from hits")};
69 
71  Name("TrackModuleLabel"),
72  Comment("tag of 3D tracks to be EM/track tagged using accumulated results from hits in the "
73  "best 2D projection")};
74 
76  Name("Views"),
77  Comment("tag clusters in selected views only, or in all views if empty list")};
78  };
80  explicit EmTrackClusterId(Parameters const& p);
81 
// Non-copyable (further deleted special members are declared on lines not visible here).
82  EmTrackClusterId(EmTrackClusterId const&) = delete;
86 
87  private:
88  void produce(art::Event& e) override;
89 
// True if the given plane/view index should be processed (empty fViews = all views).
90  bool isViewSelected(int view) const;
91 
// Maximum number of hits passed to the CNN in a single prediction call.
92  size_t fBatchSize;
94  anab::MVAWriter<3> fMVAWriter; // <-------------- using 3-output CNN model
95 
101 
// Views/planes selected via the "Views" parameter; empty means all views.
102  std::vector<int> fViews;
103 
104  art::InputTag fNewClustersTag; // input tag for the clusters produced by this module
105  };
106  // ------------------------------------------------------
107 
// Constructor (the signature line is not visible in this view).
// Copies the FHiCL configuration into the data members and constructs the MVA
// writer with the "emtrack" tag (presumably the product instance name; confirm in
// MVAWriter). Builds fNewClustersTag from this module's own "module_label", so the
// module can refer to the clusters it produces.
// The recob::Cluster collection and the Cluster->Hit associations are declared as
// products only when ClusterModuleLabel is non-empty. fDoClusters / fDoTracks
// record whether the cluster and track tagging steps of produce() will run.
109  : EDProducer{config}
110  , fBatchSize(config().BatchSize())
112  , fMVAWriter(producesCollector(), "emtrack")
113  , fWireProducerLabel(config().WireLabel())
114  , fHitModuleLabel(config().HitModuleLabel())
115  , fClusterModuleLabel(config().ClusterModuleLabel())
116  , fTrackModuleLabel(config().TrackModuleLabel())
117  , fViews(config().Views())
118  ,
119 
120  fNewClustersTag(config.get_PSet().get<std::string>("module_label"),
121  "",
123  {
125 
// Cluster tagging is enabled only if an input cluster label was configured.
126  if (!fClusterModuleLabel.label().empty()) {
127  produces<std::vector<recob::Cluster>>();
128  produces<art::Assns<recob::Cluster, recob::Hit>>();
129 
131  fDoClusters = true;
132  }
133  else {
134  fDoClusters = false;
135  }
136 
// Track tagging is enabled only if an input track label was configured.
137  if (!fTrackModuleLabel.label().empty()) {
139  fDoTracks = true;
140  }
141  else {
142  fDoTracks = false;
143  }
144  }
145  // ------------------------------------------------------
146 
// ----------------------------------------------------------------------------
// produce(): per-event processing.
// In this view the signature line (original line 148) is missing. So are the
// declarations of hitMap and cluMap and the initializers of detProp and cluID.
// The comments below describe only what the visible lines do.
//
// (1) Hits from fHitModuleLabel in the selected planes are grouped by
//     cryostat/TPC/plane. They are classified by fPointIdAlg.predictIdVectors in
//     batches of at most fBatchSize, and the per-hit outputs are stored in
//     fMVAWriter. Hits for which fPointIdAlg.isInsideFiducialRegion(wire, drift)
//     is true are flagged with 1 in hitInFA.
// (2) If fDoClusters:
//     - Each non-empty input cluster is re-emitted as a new recob::Cluster. It is
//       zero-filled except for v.size(), the running index cidx, the view and the
//       plane ID.
//     - Each hit not used by any cluster becomes a single-hit cluster.
//     - Each new cluster gets an output accumulated by fMVAWriter.getOutput from
//       its hits. Multi-hit clusters pass hitInFA as a per-hit functor
//       (presumably a weight; confirm in MVAWriter).
// (3) If fDoTracks, each track gets the output accumulated from its hits in a
//     single "best" view, again passing hitInFA as the per-hit functor.
// All MVA outputs are written by fMVAWriter.saveOutputs(evt) at the end.
// Throws cet::exception in two cases:
//   - a batch prediction returns a different number of results than inputs;
//   - no view can be selected while tagging tracks.
// ----------------------------------------------------------------------------
147  void
149  {
150  mf::LogVerbatim("EmTrackClusterId") << "next event: " << evt.run() << " / " << evt.id().event();
151 
152  auto wireHandle = evt.getValidHandle<std::vector<recob::Wire>>(fWireProducerLabel);
153 
// Scratch indices reused by several loops below. In step (1), cryo/tpc are
// shadowed by the structured bindings.
154  unsigned int cryo, tpc, view;
155 
156  // ******************* get and sort hits ********************
157  auto hitListHandle = evt.getValidHandle<std::vector<recob::Hit>>(fHitModuleLabel);
158  std::vector<art::Ptr<recob::Hit>> hitPtrList;
159  art::fill_ptr_vector(hitPtrList, hitListHandle);
160 
// Fill hitMap[cryo][tpc][plane] with the Ptr keys of hits in the selected planes
// (hitMap's declaration is not visible in this view).
162  for (auto const& h : hitPtrList) {
163  view = h->WireID().Plane;
164  if (!isViewSelected(view)) continue;
165 
166  cryo = h->WireID().Cryostat;
167  tpc = h->WireID().TPC;
168 
169  hitMap[cryo][tpc][view].push_back(h.key());
170  }
171 
172  // ********************* classify hits **********************
173  auto hitID = fMVAWriter.initOutputs<recob::Hit>(
174  fHitModuleLabel, hitPtrList.size(), fPointIdAlg.outputLabels());
175 
176  auto const clockData = art::ServiceHandle<detinfo::DetectorClocksService const>()->DataFor(evt);
177  auto const detProp =
179 
180  std::vector<char> hitInFA(
181  hitPtrList.size(),
182  0); // tag hits in fid. area as 1, use 0 for hits close to the projection edges
183  for (auto const& [cryo, tpcs] : hitMap) {
184  for (auto const& [tpc, views] : tpcs) {
185  for (auto const& pview : views) {
186  view = pview.first;
187  if (!isViewSelected(view)) continue; // should not happen, hits were selected
188 
// Give fPointIdAlg the wires of this cryo/tpc/plane before predicting.
// NOTE(review): the bool returned by setWireDriftData is ignored here.
189  fPointIdAlg.setWireDriftData(clockData, detProp, *wireHandle, view, tpc, cryo);
190 
191  // (1) do all hits in this plane ------------------------------------------------
192  for (size_t idx = 0; idx < pview.second.size(); idx += fBatchSize) {
193  std::vector<std::pair<unsigned int, float>> points;
194  std::vector<size_t> keys;
195  for (size_t k = 0; k < fBatchSize; ++k) {
196  if (idx + k >= pview.second.size()) { break; } // careful about the tail
197 
198  size_t h = pview.second[idx + k]; // h is the Ptr< recob::Hit >::key()
199  const recob::Hit& hit = *(hitPtrList[h]);
200  points.emplace_back(hit.WireID().Wire, hit.PeakTime());
201  keys.push_back(h);
202  }
203 
204  auto batch_out = fPointIdAlg.predictIdVectors(points);
205  if (points.size() != batch_out.size()) {
206  throw cet::exception("EmTrackClusterId") << "hits processing failed" << std::endl;
207  }
208 
209  for (size_t k = 0; k < points.size(); ++k) {
210  size_t h = keys[k];
211  fMVAWriter.setOutput(hitID, h, batch_out[k]);
212  if (fPointIdAlg.isInsideFiducialRegion(points[k].first, points[k].second)) {
213  hitInFA[h] = 1;
214  }
215  }
216  } // hits done ------------------------------------------------------------------
217  }
218  }
219  }
220 
221  // (2) do clusters when hits are ready in all planes ----------------------------------------
222  if (fDoClusters) {
223  // **************** prepare for new clusters ****************
224  auto clusters = std::make_unique<std::vector<recob::Cluster>>();
225  auto clu2hit = std::make_unique<art::Assns<recob::Cluster, recob::Hit>>();
226 
227  // ************** get and sort input clusters ***************
228  auto cluListHandle = evt.getValidHandle<std::vector<recob::Cluster>>(fClusterModuleLabel);
229  std::vector<art::Ptr<recob::Cluster>> cluPtrList;
230  art::fill_ptr_vector(cluPtrList, cluListHandle);
231 
// Fill cluMap[cryo][tpc][plane] with the Ptr keys of clusters in the selected
// planes (cluMap's declaration is not visible in this view).
233  for (auto const& c : cluPtrList) {
234  view = c->Plane().Plane;
235  if (!isViewSelected(view)) continue;
236 
237  cryo = c->Plane().Cryostat;
238  tpc = c->Plane().TPC;
239 
240  cluMap[cryo][tpc][view].push_back(c.key());
241  }
242 
243  auto cluID =
245 
// cidx is cumulative over all cryostats, TPCs and views of this event.
246  unsigned int cidx = 0; // new clusters index
247  art::FindManyP<recob::Hit> hitsFromClusters(cluListHandle, evt, fClusterModuleLabel);
248  std::vector<bool> hitUsed(hitPtrList.size(), false); // tag hits used in clusters
249  for (auto const& pcryo : cluMap) {
250  cryo = pcryo.first;
251  for (auto const& ptpc : pcryo.second) {
252  tpc = ptpc.first;
253  for (auto const& pview : ptpc.second) {
254  view = pview.first;
255  if (!isViewSelected(view)) continue; // should not happen, clusters were pre-selected
256 
257  for (size_t c : pview.second) // c is the Ptr< recob::Cluster >::key()
258  {
259  auto v = hitsFromClusters.at(c);
// Empty input clusters are dropped: no new cluster is made and cidx is not incremented.
260  if (v.empty()) continue;
261 
262  for (auto const& hit : v) {
263  if (hitUsed[hit.key()]) {
264  mf::LogWarning("EmTrackClusterId") << "hit already used in another cluster";
265  }
266  hitUsed[hit.key()] = true;
267  }
268 
269  auto vout = fMVAWriter.getOutput<recob::Hit>(
270  v, [&](art::Ptr<recob::Hit> const& ptr) { return (float)hitInFA[ptr.key()]; });
271 
// Fraction of output 0 among the first two outputs; used only for logging.
// NOTE(review): NaN if vout[0] + vout[1] == 0. Confirm how getOutput behaves
// when every hitInFA value of the cluster is 0.
272  float pvalue = vout[0] / (vout[0] + vout[1]);
273  mf::LogVerbatim("EmTrackClusterId") << "cluster in tpc:" << tpc << " view:" << view
274  << " size:" << v.size() << " p:" << pvalue;
275 
276  clusters->emplace_back(recob::Cluster(0.0F,
277  0.0F,
278  0.0F,
279  0.0F,
280  0.0F,
281  0.0F,
282  0.0F,
283  0.0F,
284  0.0F,
285  0.0F,
286  0.0F,
287  0.0F,
288  0.0F,
289  0.0F,
290  0.0F,
291  0.0F,
292  0.0F,
293  0.0F,
294  v.size(),
295  0.0F,
296  0.0F,
297  cidx,
298  (geo::View_t)view,
299  v.front()->WireID().planeID()));
300  util::CreateAssn(*this, evt, *clusters, v, *clu2hit);
301  cidx++;
302 
303  fMVAWriter.addOutput(cluID, vout); // add copy of the input cluster
304  }
305 
306  // (2b) make single-hit clusters --------------------------------------------
// NOTE(review): this loop runs inside the cluMap iteration. Unused hits in a
// cryo/tpc/view that has no input cluster therefore never get a single-hit
// cluster.
307  for (size_t h : hitMap[cryo][tpc][view]) // h is the Ptr< recob::Hit >::key()
308  {
309  if (hitUsed[h]) continue;
310 
311  auto vout = fMVAWriter.getOutput<recob::Hit>(h);
312  float pvalue = vout[0] / (vout[0] + vout[1]);
313 
314  mf::LogVerbatim("EmTrackClusterId")
315  << "single hit in tpc:" << tpc << " view:" << view
316  << " wire:" << hitPtrList[h]->WireID().Wire
317  << " drift:" << hitPtrList[h]->PeakTime() << " p:" << pvalue;
318 
319  art::PtrVector<recob::Hit> cluster_hits;
320  cluster_hits.push_back(hitPtrList[h]);
321  clusters->emplace_back(recob::Cluster(0.0F,
322  0.0F,
323  0.0F,
324  0.0F,
325  0.0F,
326  0.0F,
327  0.0F,
328  0.0F,
329  0.0F,
330  0.0F,
331  0.0F,
332  0.0F,
333  0.0F,
334  0.0F,
335  0.0F,
336  0.0F,
337  0.0F,
338  0.0F,
339  1,
340  0.0F,
341  0.0F,
342  cidx,
343  (geo::View_t)view,
344  hitPtrList[h]->WireID().planeID()));
345  util::CreateAssn(*this, evt, *clusters, cluster_hits, *clu2hit);
346  cidx++;
347 
348  fMVAWriter.addOutput(cluID, vout); // add single-hit cluster tagging unclustered hit
349  }
// NOTE(review): this logged count is wrong.
// - cidx is cumulative over all TPCs/views and excludes skipped empty clusters,
//   so the difference is not the number of single-hit clusters made for this
//   view.
// - The size_t subtraction can wrap around.
// Logging only; no effect on the products.
350  mf::LogVerbatim("EmTrackClusterId")
351  << "...produced " << cidx - pview.second.size() << " single-hit clusters.";
352  }
353  }
354  }
355 
356  evt.put(std::move(clusters));
357  evt.put(std::move(clu2hit));
358  } // all clusters done ----------------------------------------------------------------------
359 
360  // (3) do tracks when all hits in all cryo/tpc/plane are done -------------------------------
361  if (fDoTracks) {
362  auto trkListHandle = evt.getValidHandle<std::vector<recob::Track>>(fTrackModuleLabel);
363  art::FindManyP<recob::Hit> hitsFromTracks(trkListHandle, evt, fTrackModuleLabel);
364  std::vector<std::vector<art::Ptr<recob::Hit>>> trkHitPtrList(trkListHandle->size());
365  for (size_t t = 0; t < trkListHandle->size(); ++t) {
366  auto v = hitsFromTracks.at(t);
// Count the track's hits per view.
// NOTE(review): nh has 3 entries. A hit whose View() is >= 3 writes out of
// bounds. Confirm that geo::View_t values are limited to 0..2 for the
// detector in use.
367  size_t nh[3] = {0, 0, 0};
368  for (auto const& hptr : v) {
369  ++nh[hptr->View()];
370  }
// The best view defaults to 2. An induction view wins if it has at least as
// many hits as the other induction view and more than twice view 2's count.
// A tie between views 0 and 1 goes to view 1.
371  size_t best_view = 2; // collection
372  if ((nh[0] >= nh[1]) && (nh[0] > 2 * nh[2])) best_view = 0; // ind1
373  if ((nh[1] >= nh[0]) && (nh[1] > 2 * nh[2])) best_view = 1; // ind2
374 
// If the chosen view is not selected, rotate through views 0..2; throw if none is.
375  size_t k = 0;
376  while (!isViewSelected(best_view)) {
377  best_view = (best_view + 1) % 3;
378  if (++k > 3) {
379  throw cet::exception("EmTrackClusterId") << "No views selected at all?" << std::endl;
380  }
381  }
382 
// NOTE(review): this compares Hit::View(), while hit selection in step (1) uses
// WireID().Plane. Verify that the two numberings agree for the detector in use.
383  for (auto const& hptr : v) {
384  if (hptr->View() == best_view) trkHitPtrList[t].emplace_back(hptr);
385  }
386  }
387 
388  auto trkID = fMVAWriter.initOutputs<recob::Track>(
389  fTrackModuleLabel, trkHitPtrList.size(), fPointIdAlg.outputLabels());
390  for (size_t t = 0; t < trkHitPtrList.size(); ++t) // t is the Ptr< recob::Track >::key()
391  {
392  auto vout =
393  fMVAWriter.getOutput<recob::Hit>(trkHitPtrList[t], [&](art::Ptr<recob::Hit> const& ptr) {
394  return (float)hitInFA[ptr.key()];
395  });
396  fMVAWriter.setOutput(trkID, t, vout);
397  }
398  }
399  // tracks done ------------------------------------------------------------------------------
400 
401  fMVAWriter.saveOutputs(evt);
402  }
403  // ------------------------------------------------------
404 
// isViewSelected (the signature line is not visible in this view).
// Returns true when `view` appears in fViews, or when fViews is empty, meaning no
// view filter was configured and every view is accepted.
405  bool
407  {
408  if (fViews.empty())
409  return true;
410  else {
// Linear scan; fViews comes from the "Views" FHiCL parameter.
411  for (auto k : fViews)
412  if (k == view) { return true; }
413  return false;
414  }
415  }
416  // ------------------------------------------------------
417 
419 
420 }
bool isInsideFiducialRegion(unsigned int wire, float drift) const
Definition: PointIdAlg.cxx:316
MaybeLogger_< ELseverityLevel::ELsev_info, true > LogVerbatim
std::vector< std::string > const & outputLabels() const
network output labels
Definition: PointIdAlg.h:112
enum geo::_plane_proj View_t
Enumerate the possible plane projections.
std::string string
Definition: nybbler.cc:12
geo::WireID WireID() const
Definition: Hit.h:233
void setOutput(FVector_ID id, size_t key, std::array< float, N > const &values)
Definition: MVAWriter.h:175
EDProducer(fhicl::ParameterSet const &pset)
Definition: EDProducer.h:20
fhicl::Atom< art::InputTag > WireLabel
ChannelGroupService::Name Name
void produce(art::Event &e) override
Set of hits with a 2D structure.
Definition: Cluster.h:71
WireID_t Wire
Index of the wire within its plane.
Definition: geo_types.h:580
fhicl::Atom< art::InputTag > TrackModuleLabel
art framework interface to geometry description
bool isViewSelected(int view) const
std::unordered_map< unsigned int, view_keymap > tpc_view_keymap
std::string const & label() const noexcept
Definition: InputTag.cc:79
FVector_ID initOutputs(std::string const &dataTag, size_t dataSize, std::vector< std::string > const &names=std::vector< std::string >(N,""))
const double e
fhicl::Atom< art::InputTag > ClusterModuleLabel
#define DEFINE_ART_MODULE(klass)
Definition: ModuleMacros.h:67
IDparameter< geo::WireID > WireID
Member type of validated geo::WireID parameter.
void push_back(Ptr< U > const &p)
Definition: PtrVector.h:435
static Config * config
Definition: config.cpp:1054
def move(depos, offset)
Definition: depos.py:107
fhicl::Atom< art::InputTag > HitModuleLabel
ValidHandle< PROD > getValidHandle(InputTag const &tag) const
Definition: DataViewImpl.h:441
p
Definition: test.py:223
ProductID put(std::unique_ptr< PROD > &&edp, std::string const &instance={})
Definition: DataViewImpl.h:686
bool CreateAssn(PRODUCER const &prod, art::Event &evt, std::vector< T > const &a, art::Ptr< U > const &b, art::Assns< U, T > &assn, std::string a_instance, size_t indx=UINT_MAX)
Creates a single one-to-one association.
RunNumber_t run() const
Definition: DataViewImpl.cc:71
std::array< float, N > getOutput(std::vector< art::Ptr< T > > const &items) const
Definition: MVAWriter.h:189
void saveOutputs(art::Event &evt)
Check consistency and save all the results in the event.
Definition: MVAWriter.h:325
Detector simulation of raw signals on wires.
ProducesCollector & producesCollector() noexcept
void addOutput(FVector_ID id, std::array< float, N > const &values)
Definition: MVAWriter.h:180
float PeakTime() const
Time of the signal peak, in tick units.
Definition: Hit.h:218
Declaration of signal hit object.
std::unordered_map< unsigned int, std::vector< size_t > > view_keymap
bool setWireDriftData(const detinfo::DetectorClocksData &clock_data, const detinfo::DetectorPropertiesData &det_prop, const std::vector< recob::Wire > &wires, unsigned int plane, unsigned int tpc, unsigned int cryo)
#define Comment
MaybeLogger_< ELseverityLevel::ELsev_warning, false > LogWarning
Provides recob::Track data product.
EventNumber_t event() const
Definition: EventID.h:116
2D representation of charge deposited in the TDC/wire plane
Definition: Hit.h:48
constexpr PlaneID const & planeID() const
Definition: geo_types.h:638
TCEvent evt
Definition: DataStructs.cxx:7
void fill_ptr_vector(std::vector< Ptr< T >> &ptrs, H const &h)
Definition: Ptr.h:297
std::vector< std::vector< float > > predictIdVectors(std::vector< std::pair< unsigned int, float >> points) const
Definition: PointIdAlg.cxx:240
EventID id() const
Definition: Event.cc:34
EmTrackClusterId & operator=(EmTrackClusterId const &)=delete
Track from a non-cascading particle.A recob::Track consists of a recob::TrackTrajectory, plus additional members relevant for a "fitted" track:
Definition: Track.h:49
std::unordered_map< unsigned int, tpc_view_keymap > cryo_tpc_view_keymap
cet::coded_exception< error, detail::translate > exception
Definition: exception.h:33
QTextStream & endl(QTextStream &s)