train_cnn_continue_augmented_data.py
Go to the documentation of this file.
import argparse

# Command-line interface. The help/description strings describe what this script
# actually does: load an existing model, continue training it with augmented data,
# and save the result under a new name.
parser = argparse.ArgumentParser(
    description='Continue training a saved CNN model on patches, '
                'with data augmentation (random horizontal flips).')
parser.add_argument('-c', '--config', help="JSON with script configuration", default='config.json')
parser.add_argument('-m', '--model',
                    help="input CNN model name (read from <name>_architecture.json and <name>_weights.h5)",
                    default='cnn_model')
parser.add_argument('-o', '--output', help="output CNN model name (saved in JSON and h5 files)", default='cnn_model_out')
parser.add_argument('-g', '--gpu', help="Which GPU index (exported as CUDA_VISIBLE_DEVICES)", default='0')
args = parser.parse_args()
8 
import os
# Both variables must be set before Keras/TensorFlow are imported: Keras selects its
# backend at import time, and CUDA reads the device mask when TensorFlow initializes it.
os.environ['KERAS_BACKEND'] = "tensorflow"
# Only the requested GPU(s) are visible to CUDA; note that CUDA renumbers the visible
# devices starting from 0.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

import sys

import tensorflow as tf
import keras
# Compare the major version component, not just the first character of the string.
if keras.__version__.split('.')[0] != '2':
    # sys.exit(msg) prints msg to stderr and exits with status 1; quit() would exit with
    # status 0 and is only defined when the site module is loaded.
    sys.exit('Please use the newest Keras 2.x.x API with the Tensorflow backend')
# 'channels_last' is the Keras 2 spelling of the deprecated set_image_dim_ordering('tf'),
# so the old call is not needed.
keras.backend.set_image_data_format('channels_last')

import numpy as np
np.random.seed(2017)  # for reproducibility
23 from keras.preprocessing.image import ImageDataGenerator
24 from keras.models import model_from_json
25 from keras.optimizers import SGD
26 from keras.utils import np_utils
27 from os.path import exists, isfile, join
28 import json
29 
30 from utils import read_config, get_patch_size, count_events, shuffle_in_place
31 
def load_model(name):
    """Rebuild a model from '<name>_architecture.json' and load '<name>_weights.h5' into it.

    Args:
        name: path prefix shared by the architecture (JSON) and weights (h5) files.

    Returns:
        The model returned by ``model_from_json`` with its weights loaded.
    """
    architecture_path = name + '_architecture.json'
    weights_path = name + '_weights.h5'
    with open(architecture_path) as architecture_file:
        json_text = architecture_file.read()
    net = model_from_json(json_text)
    net.load_weights(weights_path)
    return net
37 
def save_model(model, name):
    """Save *model* as '<name>_architecture.json' (JSON) plus '<name>_weights.h5' (weights).

    Args:
        model: object providing ``to_json()`` and ``save_weights(path, overwrite=...)``.
        name: path prefix for the two output files.

    Returns:
        True if both files were written, False if serializing or writing raised an
        exception. The error is printed instead of propagated, so the caller can report it.
    """
    try:
        with open(name + '_architecture.json', 'w') as f:
            f.write(model.to_json())
        model.save_weights(name + '_weights.h5', overwrite=True)
    except Exception as e:
        # Exception (not a bare except) lets Ctrl-C / SystemExit stop the script instead
        # of being reported as an ordinary failed save.
        print('Error saving model %s: %s' % (name, e))
        return False  # Save failed
    return True  # Save successful
46 
####################### configuration #############################
print('Reading configuration...')

config = read_config(args.config)

cfg_name = args.model
out_name = args.output

CNN_INPUT_DIR = config['training_on_patches']['input_dir']
# input image dimensions
PATCH_SIZE_W, PATCH_SIZE_D = get_patch_size(CNN_INPUT_DIR)
img_rows, img_cols = PATCH_SIZE_W, PATCH_SIZE_D

batch_size = config['training_on_patches']['batch_size']
nb_epoch = config['training_on_patches']['nb_epoch']
nb_classes = config['training_on_patches']['nb_classes']

###################### CNN compilation ############################
print('Compiling CNN model...')
# CUDA_VISIBLE_DEVICES is set to args.gpu at start-up, so TensorFlow sees the selected
# GPU as device 0. '/gpu:' + args.gpu would name a non-existent device for any index
# other than 0.
with tf.device('/gpu:0'):
    model = load_model(cfg_name)

    sgd = SGD(lr=0.002, decay=1e-5, momentum=0.9, nesterov=True)
    # Two-headed network: the EM/track/none head uses cross-entropy; the Michel head uses MSE
    # and gets 10x the weight of the first head.
    model.compile(optimizer=sgd,
                  loss={'em_trk_none_netout': 'categorical_crossentropy',
                        'michel_netout': 'mean_squared_error'},
                  loss_weights={'em_trk_none_netout': 0.1, 'michel_netout': 1.0})
73 
####################### read data sets ############################
def _read_data_set(kind):
    """Load all '*<kind>*' subdirectories of CNN_INPUT_DIR into preallocated arrays.

    Each subdirectory holds pairs of '<stem>_x.npy' (patches) and '<stem>_y.npy'
    (labels). Label columns 0, 1 and 3 form the EM/track/none target; column 2 is
    the Michel target.

    Args:
        kind: 'training' or 'testing'; matched as a substring of subdirectory names.

    Returns:
        (X, EmTrkNone, Michel) with shapes (n, W, D, 1) float32, (n, 3) int32 and
        (n, 1) int32, where n is the number of events actually read.
    """
    n_events = count_events(CNN_INPUT_DIR, kind)
    X = np.zeros((n_events, PATCH_SIZE_W, PATCH_SIZE_D, 1), dtype=np.float32)
    EmTrkNone = np.zeros((n_events, 3), dtype=np.int32)
    Michel = np.zeros((n_events, 1), dtype=np.int32)
    print('%s data size: %s events; patch size: %s x %s'
          % (kind.capitalize(), n_events, PATCH_SIZE_W, PATCH_SIZE_D))

    ntot = 0
    subdirs = sorted(d for d in os.listdir(CNN_INPUT_DIR) if kind in d)
    for dirname in subdirs:
        print('Reading data in %s' % dirname)
        dirpath = os.path.join(CNN_INPUT_DIR, dirname)
        # Sorted so that events are loaded in the same order on every run
        # (os.listdir order is arbitrary).
        filesX = sorted(f for f in os.listdir(dirpath) if '_x.npy' in f)
        for fnameX in filesX:
            print('...%s data %s' % (kind, fnameX))
            fnameY = fnameX.replace('_x.npy', '_y.npy')
            # copy=False: no conversion cost when the file is already float32.
            dataX = np.load(os.path.join(dirpath, fnameX)).astype(np.float32, copy=False)
            dataY = np.load(os.path.join(dirpath, fnameY))
            n = dataY.shape[0]
            X[ntot:ntot + n] = dataX.reshape(n, img_rows, img_cols, 1)
            EmTrkNone[ntot:ntot + n] = dataY[:, [0, 1, 3]]
            Michel[ntot:ntot + n] = dataY[:, [2]]
            ntot += n
    print('%s events ready' % ntot)

    if ntot != n_events:
        # count_events() and the files on disk disagree. Keep only the rows actually
        # filled so untouched all-zero patches/labels are not trained or scored on.
        print('Warning: expected %s %s events but read %s' % (n_events, kind, ntot))
        X, EmTrkNone, Michel = X[:ntot], EmTrkNone[:ntot], Michel[:ntot]
    return X, EmTrkNone, Michel


X_train, EmTrkNone_train, Michel_train = _read_data_set('training')
X_test, EmTrkNone_test, Michel_test = _read_data_set('testing')

print('Training %s testing %s' % (X_train.shape, X_test.shape))
131 
########################## training ###############################
# Augmentation consists of random horizontal flips only; the normalization,
# whitening, rotation and shift options are all explicitly switched off.
datagen = ImageDataGenerator(
    horizontal_flip=True,  # randomly mirror patches left-right
    vertical_flip=False,
    rotation_range=0,
    width_shift_range=0,
    height_shift_range=0,
    featurewise_center=False,
    featurewise_std_normalization=False,
    samplewise_center=False,
    samplewise_std_normalization=False,
    zca_whitening=False)
# fit() computes the data-dependent statistics used by the featurewise/ZCA options;
# with those disabled above it presumably has no effect on the batches (TODO confirm).
datagen.fit(X_train)
142 
def generate_data_generator(generator, X, Y1, Y2, b):
    """Yield augmented ``(inputs, targets)`` batches for the two-output network, forever.

    A single ``generator.flow`` iterator is driven with the sample *indices* as its
    targets. Each batch's index array then selects the matching rows of both Y1 and
    Y2. This keeps images and both label sets aligned without depending on two
    identically seeded iterators being advanced in lockstep, and each image batch is
    augmented once instead of twice.

    Args:
        generator: object with ``flow(x, y, batch_size=..., seed=...)`` returning an
            iterator of ``(batch_x, batch_y)`` (e.g. an ImageDataGenerator).
        X: input patches; Y1: EM/track/none targets; Y2: Michel targets. All three are
            indexed along their first axis and must have the same length.
        b: batch size.

    Yields:
        ``({'main_input': batch_x},
           {'em_trk_none_netout': Y1[idx], 'michel_netout': Y2[idx]})``
    """
    flow = generator.flow(X, np.arange(len(X)), batch_size=b, seed=7)
    while True:
        # next() (not .next()) works with both Python 2 and Python 3 iterators.
        batch_x, batch_idx = next(flow)
        yield ({'main_input': batch_x},
               {'em_trk_none_netout': Y1[batch_idx], 'michel_netout': Y2[batch_idx]})
150 
print('Fit config: %s' % cfg_name)
# Floor division keeps the step count an integer on Python 3 as well (where '/' would
# give a float). max(1, ...) avoids a zero-step epoch when there are fewer training
# events than one batch.
steps_per_epoch = max(1, X_train.shape[0] // batch_size)
h = model.fit_generator(
    generate_data_generator(datagen, X_train, EmTrkNone_train, Michel_train, b=batch_size),
    validation_data=(
        {'main_input': X_test},
        {'em_trk_none_netout': EmTrkNone_test, 'michel_netout': Michel_test}),
    steps_per_epoch=steps_per_epoch, epochs=nb_epoch,
    verbose=1)
159 
# Drop this script's references to the training arrays so they can be garbage-collected.
X_train = None
EmTrkNone_train = None
Michel_train = None

score = model.evaluate({'main_input': X_test},
                       {'em_trk_none_netout': EmTrkNone_test, 'michel_netout': Michel_test},
                       verbose=0)
# Single formatted argument: under Python 2, print('a', b) would print a tuple.
print('Test score: %s' % (score,))

X_test = None
EmTrkNone_test = None
Michel_test = None
#####################################################################

print(h.history['loss'])
print(h.history['val_loss'])

# Save under the name given by -o/--output ("output CNN model name"). Previously the
# input model name was appended to it, e.g. 'cnn_model_outcnn_model'.
if save_model(model, out_name):
    print('All done!')
else:
    print('Error: model not saved.')
181 
open(file, mode='r')
Python built-in: opens a file and returns a file object.
def count_events(folder, key)
Definition: utils.py:9
def get_patch_size(folder)
Definition: utils.py:20
def read_config(cfgname)
Definition: utils.py:192