5 from os.path
import isfile, join
9 from utils
import read_config, get_data, get_patch, get_vertices, get_nu_vertices
13 parser = argparse.ArgumentParser(description=
'Makes training data set for various vertex/decay ID')
14 parser.add_argument(
'-c',
'--config', help=
"JSON with script configuration", default=
'config.json')
15 parser.add_argument(
'-t',
'--type', help=
"Input file format")
16 parser.add_argument(
'-i',
'--input', help=
"Input directory")
17 parser.add_argument(
'-o',
'--output', help=
"Output directory")
18 parser.add_argument(
'-v',
'--view', help=
"view")
19 args = parser.parse_args()
23 print '#'*50,
'\nPrepare data for CNN' 24 INPUT_TYPE = config[
'prepare_data_vtx_id'][
'input_type']
25 INPUT_DIR = config[
'prepare_data_vtx_id'][
'input_dir']
26 OUTPUT_DIR = config[
'prepare_data_vtx_id'][
'output_dir']
27 PATCH_SIZE_W = config[
'prepare_data_vtx_id'][
'patch_size_w']
28 PATCH_SIZE_D = config[
'prepare_data_vtx_id'][
'patch_size_d']
29 print 'Using %s as input dir, and %s as output dir' % (INPUT_DIR, OUTPUT_DIR)
32 rootModule = config[
'prepare_data_vtx_id'][
'module_name']
33 selected_view_idx = config[
'prepare_data_vtx_id'][
'selected_view_idx']
34 nearby_empty = config[
'prepare_data_vtx_id'][
'nearby_empty']
35 nearby_on_track = config[
'prepare_data_vtx_id'][
'nearby_on_track']
36 crop_event = config[
'prepare_data_vtx_id'][
'crop_event']
38 print 'Using', nearby_empty,
'empty and', nearby_on_track,
'on track patches per each verex in view', selected_view_idx
41 db = np.zeros((max_capacity, PATCH_SIZE_W, PATCH_SIZE_D), dtype=np.float32)
42 db_y = np.zeros((max_capacity, 5), dtype=np.int32)
60 if INPUT_TYPE ==
"root":
61 fnames = [f
for f
in os.listdir(INPUT_DIR)
if '.root' in f]
63 rootFile = TFile(INPUT_DIR+
'/'+n)
64 keys = [rootModule+
'/'+k.GetName()[:-4]
for k
in rootFile.Get(rootModule).GetListOfKeys()
if '_raw' in k.GetName()]
65 event_list.append((rootFile, keys))
67 keys = [f[:-4]
for f
in os.listdir(INPUT_DIR)
if '.raw' in f]
68 event_list.append((INPUT_DIR, keys))
70 for entry
in event_list:
72 event_names = entry[1]
74 for evname
in event_names:
75 finfo = evname.split(
'_')
77 tpc_idx =
int(finfo[8])
78 view_idx =
int(finfo[10])
80 if view_idx != selected_view_idx:
continue 83 print 'Process event', fcount, evname,
'NO.', evt_no
86 raw, deposit, pdg, tracks, showers =
get_data(folder, evname, PATCH_SIZE_D/2 + 2, crop_event)
88 print 'Skip empty event...' 93 print 'Found', vtx.shape[0],
'hadronic vertices/decay', nuvtx.shape[0],
'neutrino vertices' 95 for v
in range(vtx.shape[0]):
101 if nuvtx.shape[0] > 0:
104 if (flags & kHadr) > 0
or (flags & kDecay) > 0
or ((flags & kPi0) > 0
and (flags & kConv) > 0):
109 x_start = np.max([0, wire - PATCH_SIZE_W/2])
110 x_stop = np.min([raw.shape[0], x_start + PATCH_SIZE_W])
112 y_start = np.max([0, drif - PATCH_SIZE_D/2])
113 y_stop = np.min([raw.shape[1], y_start + PATCH_SIZE_D])
115 if x_stop - x_start != PATCH_SIZE_W
or y_stop - y_start != PATCH_SIZE_D:
118 target = np.zeros(5, dtype=np.int32)
122 elif (flags & kDecay) > 0:
125 elif (flags & kHadr) > 0:
128 elif (flags & kConv) > 0:
132 patch =
get_patch(raw, wire, drif, PATCH_SIZE_W, PATCH_SIZE_D)
133 if cnt_ind < max_capacity:
135 db_y[cnt_ind] = target
141 while n_empty < nearby_empty
and n_trials < 500:
142 wi = np.random.randint(x_start+1, x_stop-1)
143 di = np.random.randint(y_start+1, y_stop-1)
144 if (wi < wire-1
or wi > wire+1)
and (di < drif-2
or di > drif+2):
145 if tracks[wi,di] == 0
and showers[wi,di] == 0:
146 if cnt_ind < max_capacity:
147 patch =
get_patch(raw, wi, di, PATCH_SIZE_W, PATCH_SIZE_D)
148 target = np.zeros(5, dtype=np.int32)
151 db_y[cnt_ind] = target
160 while n_track < nearby_on_track
and n_trials < 500:
161 wi = np.random.randint(x_start+1, x_stop-1)
162 di = np.random.randint(y_start+1, y_stop-1)
163 if (wi < wire-1
or wi > wire+1)
and (di < drif-2
or di > drif+2):
164 if tracks[wi,di] == 1
or showers[wi,di] == 1:
165 if cnt_ind < max_capacity:
166 patch =
get_patch(raw, wi, di, PATCH_SIZE_W, PATCH_SIZE_D)
167 target = np.zeros(5, dtype=np.int32)
170 db_y[cnt_ind] = target
177 print 'Total size', cnt_ind,
':: hadronic:', cnt_vtx,
'decays:', cnt_decay,
'nu-primary:', cnt_nu,
'g-conv:', cnt_gamma,
'empty:', cnt_void,
'on-track:', cnt_trk
179 np.save(OUTPUT_DIR+
'/db_view_'+
str(selected_view_idx)+
'_x', db[:cnt_ind])
180 np.save(OUTPUT_DIR+
'/db_view_'+
str(selected_view_idx)+
'_y', db_y[:cnt_ind])
182 if __name__ ==
"__main__":
def get_patch(a, wire, drift, wsize, dsize)
def get_data(folder, fname, drift_margin=0, crop=True, blur=None, white_noise=0, coherent_noise=0)