run_cnn_3class.py
Go to the documentation of this file.
import argparse

# Command-line interface. Numeric options are parsed with type=int so that a
# malformed value is rejected by argparse with a usage message, instead of
# failing later with a ValueError traceback. --gpu stays a string because it
# is compared with '-1' and concatenated into the Theano device name ('gpu' + N).
parser = argparse.ArgumentParser(description='Run CNN over a full 2D projection.')
parser.add_argument('-i', '--input', help="Input file", default='datadump_hist.root')  # '/eos/user/r/rosulej/ProtoDUNE/datadump/datadump_hist.root'
parser.add_argument('-e', '--event', help="Event index", type=int, default=0)
parser.add_argument('-m', '--module', help="LArSoft module name", default='datadump')
parser.add_argument('-f', '--full', help="Full 2D plane (1), or not-empty pixels only (0)", type=int, choices=(0, 1), default=1)
parser.add_argument('-n', '--net', help="Network model name (json + h5 files)", default='model')  # /eos/user/r/rosulej/models/pdune_em-trk-michel_clean_iter350
parser.add_argument('-g', '--gpu', help="GPU index to use (default is CPU)", default='-1')
parser.add_argument('-r', '--rows', help="Patch rows (wires)", type=int, default=44)
parser.add_argument('-c', '--cols', help="Patch cols (ticks)", type=int, default=48)
args = parser.parse_args()
12 
13 import numpy as np
14 import matplotlib.pyplot as plt
15 from ROOT import TFile
16 from utils import get_data, get_patch
17 
import sys  # sys.exit for fatal configuration errors

import theano
import theano.sandbox.cuda

# Choose the Theano device from --gpu: '-1' means CPU, otherwise 'gpu<N>'.
gpu = str(args.gpu)
if gpu == '-1':
    theano.sandbox.cuda.use('cpu')
    print('Running on CPU, use -g option to say which device index should be used.')
else:
    theano.sandbox.cuda.use('gpu' + gpu)

import os
os.environ['KERAS_BACKEND'] = "theano"  # set before `import keras` below

import keras
from keras.models import model_from_json

print('Software versions: Theano %s, Keras %s' % (theano.__version__, keras.__version__))
if keras.backend.backend() != 'theano':
    # sys.exit(message) writes the message to stderr and exits with status 1,
    # so a wrong backend is reported as a failure (quit() exited with 0).
    sys.exit('**** You should be using Theano backend now...****')
# 'th' ordering: channel axis before rows/cols, matching the
# (nsamples, channel, rows, cols) reshape used for prediction.
keras.backend.set_image_dim_ordering('th')
37 
def load_model(name):
    """Rebuild a Keras network from its JSON architecture and HDF5 weights.

    Two files sharing the prefix ``name`` are read:
    ``<name>_architecture.json`` (model structure) and
    ``<name>_weights.h5`` (trained weights).

    Args:
        name: path prefix common to both files.

    Returns:
        The Keras model with weights loaded. This function does not compile it.
    """
    with open(name + '_architecture.json') as json_file:
        architecture = json_file.read()
    net = model_from_json(architecture)
    net.load_weights(name + '_weights.h5')
    return net
43 
44 
import sys


def _select_pixels(raw, tracks, showers, full2d):
    """Return the (rows, cols) index arrays of the pixels to classify.

    With ``full2d`` true every pixel of ``raw`` is selected. Otherwise only
    pixels where ``tracks == 1`` or ``showers == 1`` are selected.
    ``np.nonzero`` enumerates a 2D mask in row-major order. That is the same
    order as nested ``for r`` / ``for c`` loops, so the i-th prediction maps
    back to pixel (rows[i], cols[i]).
    """
    if full2d:
        selected = np.ones(raw.shape, dtype=bool)
    else:
        selected = (tracks == 1) | (showers == 1)
    return np.nonzero(selected)


def _postprocess(pred, rows, cols, shape, mask_thr):
    """Scatter normalized CNN scores onto the 2D plane and build the ROI mask.

    Args:
        pred: (n_pixels, n_classes) network output; row i belongs to pixel
            (rows[i], cols[i]). Column 0 is plotted as track-like, column 1
            as EM-like, and the last column is taken as the "none" label.
        rows, cols: pixel coordinates from ``_select_pixels``.
        shape: 2D shape of the plane.
        mask_thr: minimum value of ``1 - P(none)`` for a pixel to enter the ROI.

    Returns:
        (outputs, mask). ``outputs`` is float32 with shape
        (n_classes,) + shape. ``mask`` is int32: +1 where the pixel is in the
        ROI and EM-like beats track-like, -1 where it is in the ROI otherwise,
        and 0 elsewhere.
    """
    outputs = np.zeros((pred.shape[1],) + tuple(shape), dtype=np.float32)
    mask = np.zeros(shape, dtype=np.int32)

    none_idx = pred.shape[1] - 1  # here is "none" label usually
    pnorm = pred[:, 0] + pred[:, 1] + pred[:, none_idx]

    # Only pixels with a positive normalization are written. The others keep
    # 0 in every map and stay out of the mask, so no normalization factor is
    # ever used outside this guard.
    valid = pnorm > 0
    r, c, p = rows[valid], cols[valid], pred[valid]
    pn = 1.0 / pnorm[valid]

    outputs[0, r, c] = p[:, 0] * pn
    outputs[1, r, c] = p[:, 1] * pn
    # Written last: the map at none_idx holds "1 - P(none)" (a non-empty score).
    outputs[none_idx, r, c] = 1 - p[:, none_idx] * pn

    in_roi = outputs[none_idx, r, c] > mask_thr
    em_wins = outputs[1, r, c] > outputs[0, r, c]
    mask[r[in_roi], c[in_roi]] = np.where(em_wins[in_roi], 1, -1)
    return outputs, mask


PATCH_SIZE_W = int(args.rows)  # wires
PATCH_SIZE_D = int(args.cols)  # ticks (downsampled)
crop_event = False
mask_thr = 0.67  # minimum "1 - P(none)" for a pixel to be marked in the ROI

rootModule = args.module
rootFile = TFile(args.input)
if rootFile.IsZombie():
    sys.exit('Cannot open input file %s' % args.input)
module_dir = rootFile.Get(rootModule)
if not module_dir:
    sys.exit('Module %s not found in %s' % (rootModule, args.input))

# Event names are the '<event>_raw' keys of the module directory with the
# '_raw' suffix removed.
keys = [rootModule + '/' + k.GetName()[:-len('_raw')]
        for k in module_dir.GetListOfKeys() if k.GetName().endswith('_raw')]
event_index = int(args.event)
if not -len(keys) <= event_index < len(keys):
    sys.exit('Event index %d out of range: %d events in %s/%s'
             % (event_index, len(keys), args.input, rootModule))
evname = keys[event_index]

# Integer division keeps the drift margin an int under Python 3 as well.
raw, deposit, pdg, tracks, showers = get_data(rootFile, evname, PATCH_SIZE_D // 2 + 2, crop_event)

# One consistent meaning of --full: any non-zero value classifies the whole
# plane, 0 classifies only pixels marked as track or shower.
full2d = int(args.full) != 0
rows, cols = _select_pixels(raw, tracks, showers, full2d)
total_patches = len(rows)  # exact count; pixels marked both track and shower count once
print('Number of pixels: %d' % total_patches)
if total_patches == 0:
    sys.exit('No pixels selected in %s, nothing to classify.' % evname)

inputs = np.zeros((total_patches, PATCH_SIZE_W, PATCH_SIZE_D), dtype=np.float32)
for i, (r, c) in enumerate(zip(rows, cols)):
    inputs[i] = get_patch(raw, int(r), int(c), PATCH_SIZE_W, PATCH_SIZE_D)
print('%s %d' % (inputs.shape, total_patches))

model_name = args.net
m = load_model(model_name)
m.compile(loss='mean_squared_error', optimizer='sgd')

print('running CNN...')
pred = m.predict(inputs.reshape(inputs.shape[0], 1, PATCH_SIZE_W, PATCH_SIZE_D))  # nsamples, channel, rows, cols
print('...done %s' % (pred.shape,))

outputs, mask = _postprocess(pred, rows, cols, raw.shape, mask_thr)

fig, ax = plt.subplots(2, 3, figsize=(28, 14))
# (axis, image, colormap, title) for each panel. The probability maps are
# negated, presumably to flip the CMRmap colour scale.
panels = [
    (ax[0, 0], np.transpose(pdg & 0xFF), 'gist_ncar', 'PDG'),
    (ax[0, 1], np.transpose(deposit), 'jet', 'MC truth'),
    (ax[0, 2], np.transpose(raw), 'jet', 'ADC'),
    (ax[1, 0], -np.transpose(outputs[0]), 'CMRmap', 'P(track-like)'),
    (ax[1, 1], -np.transpose(outputs[1]), 'CMRmap', 'P(EM-like)'),
    (ax[1, 2], np.transpose(mask), 'RdYlGn', 'ROI'),
]
for axis, image, cmap, title in panels:
    cs = axis.pcolor(image, cmap=cmap)
    axis.set_title(title)
    fig.colorbar(cs, ax=axis)

# Alternative view of the last panel:
#cs = ax[1,2].contour(np.transpose(outputs[-1]), levels=[0, 0.33, 0.67, 1], colors=('black' , (0.6, 0.6, 1), (0.9, 0, 0), 'green'))
#ax[1,2].set_title('1 - P(empty)')

plt.tight_layout()
plt.show()
def load_model(name)
int open(const char *, int)
Opens a file descriptor.
def get_patch(a, wire, drift, wsize, dsize)
Definition: utils.py:127
def get_data(folder, fname, drift_margin=0, crop=True, blur=None, white_noise=0, coherent_noise=0)
Definition: utils.py:33