prepare_data_cnn_vtx-id.py
Go to the documentation of this file.
1 from ROOT import TFile
2 import numpy as np
3 from sys import argv
4 from os import listdir
5 from os.path import isfile, join
6 import os, json
7 import argparse
8 
9 from utils import read_config, get_data, get_patch, get_vertices, get_nu_vertices
10 
# Vertex-type bit flags carried in column 2 of the arrays returned by
# get_vertices(); they are tested with '&' in _vertex_class().
K_HADR = 0x1   # hadronic inelastic scattering
K_PI0 = 0x2    # pi0 produced in this vertex
K_DECAY = 0x4  # point of particle decay (except pi0 decays)
K_CONV = 0x8   # gamma conversion

# One-hot target layout: [decay, hadronic_vtx, g_conversion, nu_primary, not_vtx].
N_CLASSES = 5
IDX_DECAY, IDX_HADR, IDX_CONV, IDX_NU, IDX_NOT_VTX = range(N_CLASSES)

# Per-category sample counters reported at the end of main().
SAMPLE_TAGS = ('hadronic', 'decays', 'nu-primary', 'g-conv', 'empty', 'on-track')

DEFAULT_MAX_CAPACITY = 300000  # used when the config has no 'max_capacity' key
MAX_SAMPLING_TRIALS = 500      # random draws per vertex for each nearby-patch kind


class _PatchStore(object):
    """Fixed-capacity buffer of training patches and their one-hot targets."""

    def __init__(self, capacity, patch_w, patch_d):
        self.x = np.zeros((capacity, patch_w, patch_d), dtype=np.float32)
        self.y = np.zeros((capacity, N_CLASSES), dtype=np.int32)
        self.size = 0
        self.counts = dict.fromkeys(SAMPLE_TAGS, 0)

    @property
    def full(self):
        """True once no more samples can be stored."""
        return self.size >= self.x.shape[0]

    def add(self, patch, class_idx, tag):
        """Store one sample with target class `class_idx`; return False if full.

        `tag` selects the SAMPLE_TAGS counter to increment. Counters are only
        incremented for samples that were actually stored.
        """
        if self.full:
            return False
        self.x[self.size] = patch
        self.y[self.size, class_idx] = 1
        self.counts[tag] += 1
        self.size += 1
        return True


def _iter_event_sources(input_type, input_dir, root_module):
    """Yield (source, event_names) pairs to be passed on to get_data().

    For input_type 'root': one pair per '*.root' file in input_dir (sorted),
    holding the opened TFile and the '<root_module>/<key>' names of its keys
    ending in '_raw', with that 4-character suffix removed. Each file is closed
    when the consumer asks for the next pair or abandons the generator.
    Unreadable files, and files without a `root_module` directory, are skipped.

    For any other input_type: a single pair holding input_dir and the sorted
    names of its '*.raw' files without the extension.
    """
    if input_type != 'root':
        yield input_dir, sorted(f[:-4] for f in os.listdir(input_dir) if f.endswith('.raw'))
        return

    for fname in sorted(os.listdir(input_dir)):
        if not fname.endswith('.root'):
            continue
        root_file = TFile(os.path.join(input_dir, fname))
        try:
            if root_file.IsZombie():
                print('Skip unreadable file %s' % fname)
                continue
            module_dir = root_file.Get(root_module)
            if not module_dir:  # null pointer: module directory not present
                print('Skip file without %s directory: %s' % (root_module, fname))
                continue
            keys = [root_module + '/' + k.GetName()[:-4]
                    for k in module_dir.GetListOfKeys()
                    if k.GetName().endswith('_raw')]
            yield root_file, keys
        finally:
            root_file.Close()


def _nu_flags_at(nuvtx, wire, drift):
    """Return the flags of the neutrino vertex at pixel (wire, drift), or 0.

    NOTE(review): assumes nuvtx rows are [wire, drift, flags], like the vtx
    rows read in main(); confirm against utils.get_nu_vertices.
    """
    if nuvtx.shape[0] == 0:
        return 0
    hits = nuvtx[(nuvtx[:, 0] == wire) & (nuvtx[:, 1] == drift), 2]
    return int(hits[0]) if hits.size else 0


def _vertex_class(flags, nuflags):
    """Return (class index, counter tag) for a vertex, or None to skip it.

    Only vertices flagged as hadronic, as decays, or with both K_PI0 and
    K_CONV set are used. Among those, the priority is:
    nu-primary (any nuflags) > decay > hadronic > gamma conversion.
    """
    is_used = (flags & K_HADR) or (flags & K_DECAY) or ((flags & K_PI0) and (flags & K_CONV))
    if not is_used:
        return None
    if nuflags > 0:
        return IDX_NU, 'nu-primary'
    if flags & K_DECAY:
        return IDX_DECAY, 'decays'
    if flags & K_HADR:
        return IDX_HADR, 'hadronic'
    return IDX_CONV, 'g-conv'  # only K_PI0 & K_CONV can remain here


def _patch_window(wire, drift, patch_w, patch_d, shape):
    """Return (x_start, x_stop, y_start, y_stop) around (wire, drift), or None.

    The window starts half a patch before the vertex, clamped at 0, and spans
    one full patch. None is returned when it would run past `shape`. Because
    of the clamping, vertices near the low-index borders are kept, and only
    those near the high-index borders are rejected.
    """
    x_start = max(0, wire - patch_w // 2)
    x_stop = min(shape[0], x_start + patch_w)
    y_start = max(0, drift - patch_d // 2)
    y_stop = min(shape[1], y_start + patch_d)
    if x_stop - x_start != patch_w or y_stop - y_start != patch_d:
        return None
    return x_start, x_stop, y_start, y_stop


def _sample_nearby(store, raw, accept, n_wanted, window, wire, drift,
                   patch_w, patch_d, tag, max_trials=MAX_SAMPLING_TRIALS):
    """Add up to n_wanted not-vertex patches centred near a vertex.

    Candidate centres are drawn uniformly from the interior of `window`, i.e.
    without its first and last wire/drift index. A candidate is kept only if
    all of the following hold:
      - its wire is more than 1 away from `wire`;
      - its drift is more than 2 away from `drift`;
      - accept[wi, di] is true.
    At most `max_trials` candidates are drawn. Sampling stops early if the
    store fills up.
    """
    x_start, x_stop, y_start, y_stop = window
    if x_stop - x_start <= 2 or y_stop - y_start <= 2:
        return  # no interior to sample from; np.random.randint needs low < high
    n_found = 0
    for _ in range(max_trials):
        if n_found >= n_wanted or store.full:
            return
        wi = np.random.randint(x_start + 1, x_stop - 1)
        di = np.random.randint(y_start + 1, y_stop - 1)
        away_from_vtx = ((wi < wire - 1 or wi > wire + 1) and
                         (di < drift - 2 or di > drift + 2))
        if away_from_vtx and accept[wi, di]:
            store.add(get_patch(raw, wi, di, patch_w, patch_d), IDX_NOT_VTX, tag)
            n_found += 1


def main(argv):
    """Build the vertex/decay-ID CNN training set and save it as .npy files.

    Settings come from the 'prepare_data_vtx_id' section of the JSON config.
    The -t/-i/-o/-v options, when given, override input_type, input_dir,
    output_dir and selected_view_idx.

    For every event of the selected view, the script stores:
      - one patch around each selected vertex;
      - up to `nearby_empty` nearby patches centred on pixels with no track
        or shower;
      - up to `nearby_on_track` nearby patches centred on track/shower pixels.
    Both kinds of nearby patch are labelled not_vtx.

    Results are saved with np.save as db_view_<view>_x (patches, float32)
    and db_view_<view>_y (one-hot targets, int32) in the output directory.
    Target layout: [decay, hadronic_vtx, g_conversion, nu_primary, not_vtx].

    Args:
        argv: full command line including the program name (e.g. sys.argv),
            or None to let argparse read sys.argv itself.
    """
    parser = argparse.ArgumentParser(description='Makes training data set for various vertex/decay ID')
    parser.add_argument('-c', '--config', help="JSON with script configuration", default='config.json')
    parser.add_argument('-t', '--type', help="Input file format")
    parser.add_argument('-i', '--input', help="Input directory")
    parser.add_argument('-o', '--output', help="Output directory")
    parser.add_argument('-v', '--view', type=int, help="view")
    args = parser.parse_args(argv[1:] if argv is not None else None)  # argv[0] is the program name

    cfg = read_config(args.config)['prepare_data_vtx_id']

    # Command-line options override the config file; they used to be parsed and then ignored.
    input_type = cfg['input_type'] if args.type is None else args.type
    input_dir = cfg['input_dir'] if args.input is None else args.input
    output_dir = cfg['output_dir'] if args.output is None else args.output
    view = cfg['selected_view_idx'] if args.view is None else args.view
    patch_w = cfg['patch_size_w']
    patch_d = cfg['patch_size_d']
    root_module = cfg['module_name']  # larsoft module name used for data dumps in ROOT format
    nearby_empty = cfg['nearby_empty']  # patches near each vtx, but with empty area in the central pixel
    nearby_on_track = cfg['nearby_on_track']  # patches on tracks or showers, somewhere close to each vtx
    crop_event = cfg['crop_event']  # use true only if no crop on LArSoft level and not a noise dump
    max_capacity = cfg.get('max_capacity', DEFAULT_MAX_CAPACITY)

    print('%s\nPrepare data for CNN' % ('#' * 50))
    print('Using %s as input dir, and %s as output dir' % (input_dir, output_dir))
    print('#' * 50)
    print('Using %s empty and %s on track patches per each vertex in view %s'
          % (nearby_empty, nearby_on_track, view))

    # Fail fast on an unusable output location instead of after processing everything.
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    store = _PatchStore(max_capacity, patch_w, patch_d)
    fcount = 0
    for folder, event_names in _iter_event_sources(input_type, input_dir, root_module):
        for evname in event_names:
            if store.full:
                break
            # Event names are '_'-separated: field 2 is the event number, field 10 the view index.
            finfo = evname.split('_')
            evt_no = finfo[2]
            if int(finfo[10]) != view:
                continue
            fcount += 1
            print('Process event %d %s NO. %s' % (fcount, evname, evt_no))

            # get clipped data, margin depends on patch size in drift direction
            raw, _deposit, pdg, tracks, showers = get_data(folder, evname, patch_d // 2 + 2, crop_event)
            if raw is None:
                print('Skip empty event...')
                continue

            vtx = get_vertices(pdg)
            nuvtx = get_nu_vertices(pdg)
            print('Found %d hadronic vertices/decay %d neutrino vertices'
                  % (vtx.shape[0], nuvtx.shape[0]))

            # Allowed centres for the not-vertex samples, computed once per event.
            # Assumes tracks/showers are 2-D numpy arrays (they are indexed as [wire, drift]).
            empty_mask = (tracks == 0) & (showers == 0)
            track_mask = (tracks == 1) | (showers == 1)

            for row in vtx:
                if store.full:
                    break
                wire, drift, flags = int(row[0]), int(row[1]), int(row[2])
                label = _vertex_class(flags, _nu_flags_at(nuvtx, wire, drift))
                if label is None:
                    continue
                window = _patch_window(wire, drift, patch_w, patch_d, raw.shape)
                if window is None:
                    continue
                class_idx, tag = label
                store.add(get_patch(raw, wire, drift, patch_w, patch_d), class_idx, tag)
                _sample_nearby(store, raw, empty_mask, nearby_empty, window,
                               wire, drift, patch_w, patch_d, 'empty')
                _sample_nearby(store, raw, track_mask, nearby_on_track, window,
                               wire, drift, patch_w, patch_d, 'on-track')
        if store.full:
            print('Reached max capacity of %d samples, remaining events skipped.' % max_capacity)
            break

    counts = store.counts
    print('Total size %d :: hadronic: %d decays: %d nu-primary: %d g-conv: %d empty: %d on-track: %d'
          % (store.size, counts['hadronic'], counts['decays'], counts['nu-primary'],
             counts['g-conv'], counts['empty'], counts['on-track']))

    prefix = os.path.join(output_dir, 'db_view_' + str(view))
    np.save(prefix + '_x', store.x[:store.size])
    np.save(prefix + '_y', store.y[:store.size])


if __name__ == "__main__":
    main(argv)
184 
def get_vertices(A)
Definition: utils.py:151
def get_nu_vertices(A)
Definition: utils.py:168
def get_patch(a, wire, drift, wsize, dsize)
Definition: utils.py:127
def read_config(cfgname)
Definition: utils.py:192
def get_data(folder, fname, drift_margin=0, crop=True, blur=None, white_noise=0, coherent_noise=0)
Definition: utils.py:33
static QCString str