1 ////////////////////////////////////////////////////////////////////////
2 // Class: InfillChannels
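// Plugin Type: producer (art v3_05_01)
// File:
//
// Replaces the ADC waveforms of dead (bad + noisy) channels with values infilled by
// TorchScript networks (one for induction planes, one for collection planes).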
6 // Alex Wilkinson - 08/03/21
7 ////////////////////////////////////////////////////////////////////////
#include "fhiclcpp/ParameterSet.h"
#include "lardataobj/RawData/raw.h"

#include <TGeoVolume.h>

#include <algorithm>
#include <array>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <iterator>
#include <limits>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include <torch/script.h>
#include <torch/torch.h>
// Forward declaration of the producer; its full definition follows.
namespace Infill {
  class InfillChannels;
}
44 {
45 public:
46  explicit InfillChannels(fhicl::ParameterSet const& p);
48  // Plugins should not be copied or assigned.
49  InfillChannels(InfillChannels const&) = delete;
50  InfillChannels(InfillChannels&&) = delete;
51  InfillChannels& operator=(InfillChannels const&) = delete;
54  // Required functions.
55  void produce(art::Event& e) override;
57  // Selected optional functions.
58  void beginJob() override;
59  void endJob() override;
61 private:
62  // Declare member data here.
65  std::set<raw::ChannelID_t> fBadChannels;
66  std::set<raw::ChannelID_t> fNoisyChannels;
67  std::set<raw::ChannelID_t> fDeadChannels;
69  std::set<readout::ROPID> fActiveRops;
77 };
80  : EDProducer{p},
81  fNetworkPath (p.get<std::string> ("NetworkPath")),
82  fNetworkNameInduction (p.get<std::string> ("NetworkNameInduction")),
83  fNetworkNameCollection (p.get<std::string> ("NetworkNameCollection")),
84  fInputLabel (p.get<std::string> ("InputLabel"))
85 {
86  consumes<std::vector<raw::RawDigit>>(fInputLabel);
88  produces<std::vector<raw::RawDigit>>();
89 }
92 {
93  auto const detProp = art::ServiceHandle<detinfo::DetectorPropertiesService>()->DataFor(e);
94  // Networks expect a fixed image size
95  if (detProp.NumberTimeSamples() > 6000) {
96  std::cerr << " Networks cannot handle more than 6000 time ticks\n";
97  std::abort();
98  }
100  typedef std::array<short, 6000> vecAdc;
101  std::map<raw::ChannelID_t, vecAdc> infilledAdcs;
102  torch::Tensor maskedRopTensor;
103  torch::Tensor infilledRopTensor;
105  auto digs = e.getHandle<std::vector<raw::RawDigit> >(fInputLabel);
107  // Get infilled adc ROP by ROP
108  for (const readout::ROPID& currentRop : fActiveRops) {
109  maskedRopTensor = torch::zeros(
110  {1, 1 ,6000, fGeom->Nchannels(currentRop)}, torch::dtype(torch::kFloat32).device(torch::kCPU).requires_grad(false)
111  );
112  auto maskedRopTensorAccess = maskedRopTensor.accessor<float, 4>();
114  const raw::ChannelID_t firstCh = fGeom->FirstChannelInROP(currentRop);
116  // Fill ROP image
117  for (const raw::RawDigit& dig : *digs) {
118  if (fDeadChannels.count(dig.Channel())) continue;
120  readout::ROPID rop = fGeom->ChannelToROP(dig.Channel());
121  if (rop != currentRop) continue;
123  raw::RawDigit::ADCvector_t adcs(dig.Samples());
124  raw::Uncompress(dig.ADCs(), adcs, dig.Compression());
126  for (unsigned int tick = 0; tick < adcs.size(); ++tick) {
127  const int adc = adcs[tick] ? int(adcs[tick]) - dig.GetPedestal() : 0;
129  maskedRopTensorAccess[0][0][tick][dig.Channel() - firstCh] = adc;
130  }
131  }
133  // Do the Infill
134  std::vector<torch::jit::IValue> inputs;
135  inputs.push_back(maskedRopTensor);
136  if (fGeom->SignalType(currentRop) == geo::kInduction) {
137  torch::NoGradGuard no_grad_guard;
138  infilledRopTensor = fInductionModule.forward(inputs).toTensor().detach();
139  }
140  else if (fGeom->SignalType(currentRop) == geo::kCollection) {
141  torch::NoGradGuard no_grad_guard;
142  infilledRopTensor = fCollectionModule.forward(inputs).toTensor().detach();
143  }
145  // Store infilled ADC of dead channels
146  auto infilledRopTensorAccess = infilledRopTensor.accessor<float, 4>();
147  for (const raw::ChannelID_t ch : fDeadChannels) {
148  if (fGeom->ChannelToROP(ch) == currentRop) {
149  for (unsigned int tick = 0; tick < detProp.NumberTimeSamples(); ++tick) {
150  infilledAdcs[ch][tick] = (short)std::round(infilledRopTensorAccess[0][0][tick][ch - firstCh]);
151  }
152  }
153  } // Could break early if ch > last ch in currentRop to save a bit of looping?
154  }
156  // Encode infilled ADC into RawDigit and put back onto event
157  auto infilledDigs = std::make_unique<std::vector<raw::RawDigit>>();
158  *infilledDigs = *digs;
159  for (raw::RawDigit& dig : *infilledDigs) {
160  if (infilledAdcs.count(dig.Channel())) {
161  raw::RawDigit::ADCvector_t infilledAdc(
162  infilledAdcs[dig.Channel()].begin(), (infilledAdcs[dig.Channel()].begin() + detProp.NumberTimeSamples())
163  );
165  // Get new pedestal
166  auto infilledAdcMin = std::min_element(infilledAdc.begin(), infilledAdc.end());
167  short ped = *infilledAdcMin < 0 ? std::abs(*infilledAdcMin) + 1 : 0;
168  for (short& adc : infilledAdc) adc += ped;
170  raw::Compress(infilledAdc, dig.Compression()); // need to consider compression parameters
171  dig = raw::RawDigit(dig.Channel(), dig.Samples(), infilledAdc, dig.Compression());
172  dig.SetPedestal(ped);
173  }
174  }
175  e.put(std::move(infilledDigs));
176 }
179 {
182  // Dead channels = bad channels + noisy channels
186  std::merge(
187  fBadChannels.begin(), fBadChannels.end(), fNoisyChannels.begin(), fNoisyChannels.end(),
188  std::inserter(fDeadChannels, fDeadChannels.begin())
189  );
191  // Get active ROPs (not facing a wall and has dead channels)
194  for (iRop = rBegin; iRop != rEnd; ++iRop) { // Iterate over ROPs in the detector
195  bool hasDeadCh = false;
196  for (raw::ChannelID_t ch : fDeadChannels) {
197  if (fGeom->ChannelToROP(ch) == *iRop) {
198  hasDeadCh = true;
199  break;
200  }
201  }
202  if (!hasDeadCh) continue; // Don't need to infill ROPs without dead channels
204  for (const geo::TPCID tpcId : fGeom->ROPtoTPCs(*iRop)) {
205  const geo::TPCGeo tpc = fGeom->TPC(tpcId);
206  const TGeoVolume* tpcVol = tpc.ActiveVolume();
208  if (tpcVol->Capacity() > 1000000) { // At least one of the ROP's TPCIDs needs to be active
209  // Networks expect a fixed image size
210  if(fGeom->SignalType(*iRop) == geo::kInduction && fGeom->Nchannels(*iRop) > 800) {
211  std::cerr << " Induction view network cannot handle more then 800 channels\n";
212  std::abort();
213  }
214  if(fGeom->SignalType(*iRop) == geo::kCollection && fGeom->Nchannels(*iRop) > 480) {
215  std::cerr << " Collection view network cannot handle more then 400 channels\n";
216  std::abort();
217  }
219  fActiveRops.insert(*iRop);
220  break;
221  }
222  }
223  }
225  // Check dead channels resemble the dead channels used for training
226  raw::ChannelID_t chGap = 1;
227  for (const raw::ChannelID_t ch : fDeadChannels) {
228  if (fDeadChannels.count(ch + 1)) {
229  ++chGap;
230  continue;
231  }
232  if (fGeom->ChannelToROP(ch - chGap) == fGeom->ChannelToROP(ch + 1)) {
233  if (fGeom->SignalType(ch) == geo::kCollection && chGap > 3) {
234  std::cerr << "There are dead channel gap larger than what was seen in training --- ";
235  std::cerr << "**Consider retraining collection plane infill network**" << std::endl;
236  }
237  else if (fGeom->SignalType(ch) == geo::kInduction && chGap > 2) {
238  std::cerr << "There are dead channel gap larger than what was seen in training --- ";
239  std::cerr << "**Consider retraining induction plane infill network**" << std::endl;
240  }
241  }
242  chGap = 1;
243  }
245  // Load torchscripts
246  std::cout << "Loading modules..." << std::endl;
247  const char* networkPath = std::getenv(fNetworkPath.c_str());
248  if (networkPath == nullptr) {
249  std::cerr << " Environment variable " << fNetworkPath << " was not found";
250  std::abort();
251  }
252  const std::string networkLocInduction = std::string(networkPath) + "/" + fNetworkNameInduction;
253  const std::string networkLocCollection = std::string(networkPath) + "/" + fNetworkNameCollection;
255  try {
256  fInductionModule = torch::jit::load(networkLocInduction);
257  std::cout << "Induction module loaded from " << networkLocInduction <<std::endl;
258  }
259  catch (const c10::Error& err) {
260  std::cerr << "error loading the model\n";
261  std::cerr << err.what();
262  }
263  try {
264  fCollectionModule = torch::jit::load(networkLocCollection);
265  std::cout << "Collection module loaded from " << networkLocCollection << std::endl;
266  }
267  catch (const c10::Error& err) {
268  std::cerr << "error loading the model\n";
269  std::cerr << err.what();
270  }
271 }
274 {
275  // Implementation of optional member function here.
276 }
