run_cnn_1class.py
Go to the documentation of this file.
1 import argparse
# Command-line options. Numeric options are converted (and validated) by
# argparse itself, so a bad value is reported immediately with a usage message
# instead of failing later after the ROOT file has been opened.
parser = argparse.ArgumentParser(description='Run CNN over a full 2D projection.')
parser.add_argument('-i', '--input', help="Input file", default='datadump_hist.root')  # '/eos/user/r/rosulej/ProtoDUNE/datadump/datadump_hist.root'
parser.add_argument('-e', '--event', help="Event index", type=int, default=0)
parser.add_argument('-m', '--module', help="LArSoft module name", default='datadump')
# Only 0 and 1 are meaningful: any other value made the patch buffer smaller
# than the number of pixels actually visited (IndexError downstream).
parser.add_argument('-f', '--full', help="Full 2D plane (1), or not-empty pixels only (0)", type=int, choices=[0, 1], default=1)
parser.add_argument('-n', '--net', help="Network model name (json + h5 files)", default='model')  # /eos/user/r/rosulej/models/pdune_em-trk-michel_clean_iter350
# Kept as a string: it is compared with '-1' and appended to 'gpu' below.
parser.add_argument('-g', '--gpu', help="GPU index to use (default is CPU)", default='-1')
parser.add_argument('-r', '--rows', help="Patch rows (wires)", type=int, default=44)
parser.add_argument('-c', '--cols', help="Patch cols (ticks)", type=int, default=48)
args = parser.parse_args()
12 
13 import numpy as np
14 import matplotlib.pyplot as plt
15 from ROOT import TFile
16 from utils import get_data, get_patch
17 
18 import theano
19 import theano.sandbox.cuda
# Pick the Theano device from --gpu: the default '-1' means CPU, any other
# value N selects 'gpuN'.
use_cpu = args.gpu == '-1'
theano.sandbox.cuda.use('cpu' if use_cpu else 'gpu' + args.gpu)
if use_cpu:
    print('Running on CPU, use -g option to say which device index should be used.')
25 
26 import os
# Keras chooses its backend when it is first imported, so the environment
# variable has to be set before the imports below.
os.environ['KERAS_BACKEND'] = "theano"

import keras
from keras.models import model_from_json

print('Software versions: Theano %s, Keras %s' % (theano.__version__, keras.__version__))
if keras.backend.backend() != 'theano':
    # Abort with a non-zero exit status and the message on stderr. The
    # previous quit() is a site helper meant for the interactive prompt (it is
    # missing under `python -S`) and reported success (status 0).
    raise SystemExit('**** You should be using Theano backend now...****')

# Channels-first ('th') ordering: the network input below is reshaped to
# (nsamples, channel, rows, cols).
keras.backend.set_image_dim_ordering('th')
37 
def load_model(name):
    """Build a Keras model from '<name>_architecture.json' and load its weights.

    Args:
        name: Path prefix shared by the architecture JSON and the weights file.

    Returns:
        The model returned by ``model_from_json`` with weights read from
        '<name>_weights.h5'.
    """
    arch_path = name + '_architecture.json'
    weights_path = name + '_weights.h5'
    with open(arch_path) as arch_file:
        net = model_from_json(arch_file.read())
    net.load_weights(weights_path)
    return net
43 
44 
PATCH_SIZE_W = int(args.rows)  # wires
PATCH_SIZE_D = int(args.cols)  # ticks (downsampled)
crop_event = False

# Validate the mode before the (slow) ROOT read: the pixel loops below only
# understand 0 (track/shower pixels only) and 1 (every pixel).
full2d = int(args.full)
if full2d not in (0, 1):
    raise SystemExit('--full must be 0 or 1, got %d' % full2d)

rootModule = args.module
rootFile = TFile(args.input)
# NOTE(review): PyROOT hands back a falsy/null object (or None) for a missing
# key, which also covers an unreadable input file; confirm for the ROOT
# version in use.
folder = rootFile.Get(rootModule)
if not folder:
    raise SystemExit("Cannot read module '%s' from file '%s'" % (rootModule, args.input))

# One entry per event: '<module>/<event prefix>' for every '<prefix>_raw' key.
# Match the suffix exactly, since exactly len('_raw') characters are stripped.
keys = [rootModule + '/' + k.GetName()[:-len('_raw')]
        for k in folder.GetListOfKeys() if k.GetName().endswith('_raw')]

event_index = int(args.event)
# Negative indices keep their Python meaning (count from the last event).
if not -len(keys) <= event_index < len(keys):
    raise SystemExit('Event index %d out of range: %d events found in %s'
                     % (event_index, len(keys), args.input))
evname = keys[event_index]

# Third argument is get_data's drift_margin: half a patch (in ticks) plus 2.
# Floor division keeps it an int on Python 3 as well.
raw, deposit, pdg, tracks, showers = get_data(rootFile, evname, PATCH_SIZE_D // 2 + 2, crop_event)

if full2d == 1:
    total_patches = raw.size
else:
    # Count exactly the pixels the patch loop keeps (track or shower flag == 1).
    # Summing both masks double-counted pixels flagged as both.
    total_patches = int(np.count_nonzero((tracks == 1) | (showers == 1)))
print('Number of pixels: %d' % total_patches)
59 
# One (PATCH_SIZE_W x PATCH_SIZE_D) patch per selected pixel. The buffer is
# sized for total_patches and trimmed afterwards to the number actually filled.
inputs = np.zeros((total_patches, PATCH_SIZE_W, PATCH_SIZE_D), dtype=np.float32)

# Visit pixels in row-major (wire, drift) order. A pixel is kept unless we are
# in the "track/shower pixels only" mode (full2d == 0) and it has neither flag.
n_patches = 0
for r, c in np.ndindex(raw.shape[0], raw.shape[1]):
    wanted = full2d != 0 or tracks[r, c] == 1 or showers[r, c] == 1
    if not wanted:
        continue
    inputs[n_patches] = get_patch(raw, r, c, PATCH_SIZE_W, PATCH_SIZE_D)
    n_patches += 1

inputs = inputs[:n_patches]
print('%s %s' % (inputs.shape, n_patches))
73 
model_name = args.net
m = load_model(model_name)
# NOTE(review): loss/optimizer only matter for training; presumably compiled
# for compatibility with the Keras version in use — confirm before removing.
m.compile(loss='mean_squared_error', optimizer='sgd')

print('running CNN...')
# Channels-first input: (nsamples, channel, rows, cols).
pred = m.predict(inputs.reshape(inputs.shape[0], 1, PATCH_SIZE_W, PATCH_SIZE_D))
if pred.ndim > 1 and pred.shape[1] > 1:
    pred = pred[:, 0]  # select output 0 if multiple outputs
# flatten() returns a copy; the result was previously discarded, leaving an
# (N, 1) array whose elements are length-1 arrays rather than scalars.
pred = pred.flatten()
print('...done')
83 
# Scatter the per-patch CNN scores back onto the 2D grid, walking pixels in the
# same row-major order used to build `inputs`. Pixels skipped in the
# track/shower-only mode (full2d == 0) are marked with -1.
outputs = np.zeros((raw.shape[0], raw.shape[1]), dtype=np.float32)

pred_idx = 0
for r, c in np.ndindex(outputs.shape):
    skipped = full2d == 0 and not (tracks[r, c] == 1 or showers[r, c] == 1)
    if skipped:
        outputs[r, c] = -1
    else:
        outputs[r, c] = pred[pred_idx]
        pred_idx += 1

# To keep the map for later analysis: np.save('cnn_output.npy', outputs)
97 
# 2x2 overview: MC truth (PDG, deposit) on top, ADC input and CNN output below.
# Each array is transposed before plotting, so its first axis runs along x.
fig, ax = plt.subplots(2, 2, figsize=(17, 14))

panels = (
    # Only the low byte of pdg is shown (presumably the particle-type code;
    # confirm against utils.get_data).
    ((0, 0), pdg & 0xFF, 'gist_ncar', 'PDG'),
    ((0, 1), deposit, 'jet', 'MC truth deposit'),
    ((1, 0), raw, 'jet', 'ADC'),
    ((1, 1), outputs, 'CMRmap', 'CNN output'),
)
for pos, image, cmap_name, title in panels:
    panel = ax[pos]
    cs = panel.pcolor(np.transpose(image), cmap=cmap_name)
    panel.set_title(title)
    fig.colorbar(cs, ax=panel)

plt.tight_layout()
plt.show()
int open(const char *, int)
Opens a file descriptor.
def load_model(name)
def get_patch(a, wire, drift, wsize, dsize)
Definition: utils.py:127
def get_data(folder, fname, drift_margin=0, crop=True, blur=None, white_noise=0, coherent_noise=0)
Definition: utils.py:33