train_cnn_continue.py
Go to the documentation of this file.
import argparse

# Command-line interface. It is parsed before the heavy imports below because
# the --gpu value is needed to set CUDA_VISIBLE_DEVICES first.
_CLI_OPTIONS = (
    # (short flag, long flag, help text, default value)
    ('-c', '--config', "JSON with script configuration", 'config.json'),
    ('-m', '--model', "input CNN model name (saved in JSON and h5 files)", 'cnn_model'),
    ('-o', '--output', "output CNN model name (saved in JSON and h5 files)", 'cnn_model_out'),
    ('-g', '--gpu', "Which GPU index", '0'),
)

parser = argparse.ArgumentParser(
    description='Run CNN training on patches with a few different hyperparameter sets.')
for short_flag, long_flag, help_text, default_value in _CLI_OPTIONS:
    parser.add_argument(short_flag, long_flag, help=help_text, default=default_value)
args = parser.parse_args()
8 
import os
# The Keras backend and the visible GPU must be chosen before keras/tensorflow
# are imported, so these assignments have to stay above the imports below.
os.environ['KERAS_BACKEND'] = "tensorflow"
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

import sys
import tensorflow as tf
import keras
if keras.__version__.split('.')[0] != '2':
    # sys.exit(msg) prints msg to stderr and exits with status 1, so calling
    # scripts can detect the failure (quit() exited with status 0 and is only
    # intended for the interactive interpreter).
    sys.exit('Please use the newest Keras 2.x.x API with the Tensorflow backend')
keras.backend.set_image_data_format('channels_last')
# set_image_dim_ordering('tf') is the legacy spelling of the call above; newer
# Keras 2 releases may no longer provide it, so call it only where it exists.
if hasattr(keras.backend, 'set_image_dim_ordering'):
    keras.backend.set_image_dim_ordering('tf')

import numpy as np
np.random.seed(2017)  # for reproducibility
from keras.models import model_from_json
from keras.optimizers import SGD
from keras.utils import np_utils
from os.path import exists, isfile, join
import json

from utils import read_config, get_patch_size, count_events, shuffle_in_place
28 
29 from utils import read_config, get_patch_size, count_events, shuffle_in_place
30 
def load_model(name):
    """Rebuild a Keras model saved under the base name `name`.

    The architecture is read from '<name>_architecture.json' and passed to
    model_from_json; the weights are then loaded from '<name>_weights.h5'.
    Returns the reconstructed (uncompiled) model.
    """
    architecture_path = name + '_architecture.json'
    weights_path = name + '_weights.h5'

    with open(architecture_path) as architecture_file:
        architecture_json = architecture_file.read()
    restored = model_from_json(architecture_json)

    restored.load_weights(weights_path)
    return restored
36 
def save_model(model, name):
    """Write `model` to '<name>_architecture.json' and '<name>_weights.h5'.

    The architecture file holds the string returned by model.to_json(); the
    weights are written with model.save_weights(..., overwrite=True).

    Returns True on success. If writing fails, the error is printed and False
    is returned so the caller can report that the model was not saved.
    """
    try:
        with open(name + '_architecture.json', 'w') as f:
            f.write(model.to_json())
        model.save_weights(name + '_weights.h5', overwrite=True)
    except Exception as err:
        # Report *why* saving failed. A bare `except:` hid the reason and also
        # swallowed KeyboardInterrupt/SystemExit.
        print('Error saving model "%s": %s' % (name, err))
        return False  # Save failed
    return True  # Save successful
45 
46 ####################### configuration #############################
47 print 'Reading configuration...'
48 
49 config = read_config(args.config)
50 
51 cfg_name = args.model
52 out_name = args.output
53 
54 CNN_INPUT_DIR = config['training_on_patches']['input_dir']
55 # input image dimensions
56 PATCH_SIZE_W, PATCH_SIZE_D = get_patch_size(CNN_INPUT_DIR)
57 img_rows, img_cols = PATCH_SIZE_W, PATCH_SIZE_D
58 
59 batch_size = config['training_on_patches']['batch_size']
60 nb_epoch = config['training_on_patches']['nb_epoch']
61 nb_classes = config['training_on_patches']['nb_classes']
62 
63 ###################### CNN commpilation ###########################
64 print 'Compiling CNN model...'
65 with tf.device('/gpu:' + args.gpu):
66  model = load_model(cfg_name)
67 
68  sgd = SGD(lr=0.005, decay=1e-5, momentum=0.9, nesterov=True)
69  model.compile(optimizer=sgd,
70  loss={'em_trk_none_netout': 'categorical_crossentropy', 'michel_netout': 'mean_squared_error'},
71  loss_weights={'em_trk_none_netout': 0.1, 'michel_netout': 1.})
72 
73 ####################### read data sets ############################
74 n_training = count_events(CNN_INPUT_DIR, 'training')
75 X_train = np.zeros((n_training, PATCH_SIZE_W, PATCH_SIZE_D, 1), dtype=np.float32)
76 EmTrkNone_train = np.zeros((n_training, 3), dtype=np.int32)
77 Michel_train = np.zeros((n_training, 1), dtype=np.int32)
78 print 'Training data size:', n_training, 'events; patch size:', PATCH_SIZE_W, 'x', PATCH_SIZE_D
79 
80 ntot = 0
81 subdirs = [f for f in os.listdir(CNN_INPUT_DIR) if 'training' in f]
82 subdirs.sort()
83 for dirname in subdirs:
84  print 'Reading data in', dirname
85  filesX = [f for f in os.listdir(CNN_INPUT_DIR + '/' + dirname) if '_x.npy' in f]
86  for fnameX in filesX:
87  print '...training data', fnameX
88  fnameY = fnameX.replace('_x.npy', '_y.npy')
89  dataX = np.load(CNN_INPUT_DIR + '/' + dirname + '/' + fnameX)
90  if dataX.dtype != np.dtype('float32'):
91  dataX = dataX.astype("float32")
92  dataY = np.load(CNN_INPUT_DIR + '/' + dirname + '/' + fnameY)
93  n = dataY.shape[0]
94  X_train[ntot:ntot+n] = dataX.reshape(n, img_rows, img_cols, 1)
95  EmTrkNone_train[ntot:ntot+n] = dataY[:,[0, 1, 3]]
96  Michel_train[ntot:ntot+n] = dataY[:,[2]]
97  ntot += n
98 print ntot, 'events ready'
99 
100 n_testing = count_events(CNN_INPUT_DIR, 'testing')
101 X_test = np.zeros((n_testing, PATCH_SIZE_W, PATCH_SIZE_D, 1), dtype=np.float32)
102 EmTrkNone_test = np.zeros((n_testing, 3), dtype=np.int32)
103 Michel_test = np.zeros((n_testing, 1), dtype=np.int32)
104 print 'Testing data size:', n_testing, 'events'
105 
106 ntot = 0
107 subdirs = [f for f in os.listdir(CNN_INPUT_DIR) if 'testing' in f]
108 subdirs.sort()
109 for dirname in subdirs:
110  print 'Reading data in', dirname
111  filesX = [f for f in os.listdir(CNN_INPUT_DIR + '/' + dirname) if '_x.npy' in f]
112  for fnameX in filesX:
113  print '...testing data', fnameX
114  fnameY = fnameX.replace('_x.npy', '_y.npy')
115  dataX = np.load(CNN_INPUT_DIR + '/' + dirname + '/' + fnameX)
116  if dataX.dtype != np.dtype('float32'):
117  dataX = dataX.astype("float32")
118  dataY = np.load(CNN_INPUT_DIR + '/' + dirname + '/' + fnameY)
119  n = dataY.shape[0]
120  X_test[ntot:ntot+n] = dataX.reshape(n, img_rows, img_cols, 1)
121  EmTrkNone_test[ntot:ntot+n] = dataY[:,[0, 1, 3]]
122  Michel_test[ntot:ntot+n] = dataY[:,[2]]
123  ntot += n
124 print ntot, 'events ready'
125 
126 dataX = None
127 dataY = None
128 
129 print 'Training', X_train.shape, 'testing', X_test.shape
130 
131 ########################## training ###############################
132 print 'Fit config:', cfg_name
133 h = model.fit({'main_input': X_train},
134  {'em_trk_none_netout': EmTrkNone_train, 'michel_netout': Michel_train},
135  validation_data=(
136  {'main_input': X_test},
137  {'em_trk_none_netout': EmTrkNone_test, 'michel_netout': Michel_test}),
138  batch_size=batch_size, epochs=nb_epoch, shuffle=True,
139  verbose=1)
140 
141 X_train = None
142 EmTrkNone_train = None
143 Michel_train = None
144 
145 score = model.evaluate({'main_input': X_test},
146  {'em_trk_none_netout': EmTrkNone_test, 'michel_netout': Michel_test},
147  verbose=0)
148 print('Test score:', score)
149 
150 X_test = None
151 EmTrkNone_test = None
152 Michel_test = None
153 #####################################################################
154 
155 print h.history['loss']
156 print h.history['val_loss']
157 
158 if save_model(model, args.output + cfg_name):
159  print('All done!')
160 else:
161  print('Error: model not saved.')
162 
int open(const char *, int)
Opens a file descriptor.
def count_events(folder, key)
Definition: utils.py:9
def read_config(cfgname)
Definition: utils.py:192
def get_patch_size(folder)
Definition: utils.py:20
def save_model(model, name)