train_cnn.py
import argparse
parser = argparse.ArgumentParser(description='Run CNN training on patches with a few different hyperparameter sets.')
parser.add_argument('-c', '--config', help="JSON with script configuration", default='config.json')
parser.add_argument('-o', '--output', help="Output model file name", default='model')
parser.add_argument('-g', '--gpu', help="Which GPU index", default='0')
args = parser.parse_args()

import os
os.environ['KERAS_BACKEND'] = "tensorflow"
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

import tensorflow as tf
import keras
if keras.__version__[0] != '2':
    print 'Please use the newest Keras 2.x.x API with the Tensorflow backend'
    quit()
keras.backend.set_image_data_format('channels_last')
keras.backend.set_image_dim_ordering('tf')

import numpy as np
np.random.seed(2017)  # for reproducibility
from keras.models import Model
from keras.layers import Input
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.layers.advanced_activations import LeakyReLU
# from keras.layers.normalization import BatchNormalization
from keras.optimizers import SGD
from keras.utils import np_utils
from os.path import exists, isfile, join
import json

from utils import read_config, get_patch_size, count_events

def save_model(model, name):
    try:
        with open(name + '_architecture.json', 'w') as f:
            f.write(model.to_json())
        model.save_weights(name + '_weights.h5', overwrite=True)
        return True  # Save successful
    except:
        return False  # Save failed

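# A minimal reload sketch (assumption: not used anywhere in this script, shown
# only to illustrate how the two files written by save_model() can be read back
# with the standard Keras 2 API).
def load_model_sketch(name):
    from keras.models import model_from_json
    with open(name + '_architecture.json') as f:
        model = model_from_json(f.read())
    model.load_weights(name + '_weights.h5')
    return model
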
####################### configuration #############################
print 'Reading configuration...'
config = read_config(args.config)

CNN_INPUT_DIR = config['training_on_patches']['input_dir']
# input image dimensions
PATCH_SIZE_W, PATCH_SIZE_D = get_patch_size(CNN_INPUT_DIR)
img_rows, img_cols = PATCH_SIZE_W, PATCH_SIZE_D

batch_size = config['training_on_patches']['batch_size']
nb_classes = config['training_on_patches']['nb_classes']
nb_epoch = config['training_on_patches']['nb_epoch']

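# A minimal sketch (assumption) of the config.json consumed above; only the
# keys actually read by this script are shown, and the values are placeholders.
#
# {
#     "training_on_patches": {
#         "input_dir":  "/path/to/patch/dataset",
#         "batch_size": 256,
#         "nb_classes": 4,
#         "nb_epoch":   25
#     }
# }
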
nb_pool = 2  # size of pooling area for max pooling

cfg_name = 'sgd_lorate'

# convolutional layers:
nb_filters1 = 48  # number of convolutional filters in the first layer
nb_conv1 = 5      # 1st convolution kernel size
convactfn1 = 'relu'

maxpool = False   # max pooling between conv. layers

nb_filters2 = 0   # number of convolutional filters in the second layer
nb_conv2 = 7      # 2nd convolution kernel size
convactfn2 = 'relu'

drop1 = 0.2

# dense layers:
densesize1 = 128
actfn1 = 'relu'
densesize2 = 32
actfn2 = 'relu'
drop2 = 0.2

####################### CNN definition ############################
print 'Compiling CNN model...'
with tf.device('/gpu:' + args.gpu):
    main_input = Input(shape=(img_rows, img_cols, 1), name='main_input')

    if convactfn1 == 'leaky':
        x = Conv2D(nb_filters1, (nb_conv1, nb_conv1),
                   padding='valid', data_format='channels_last',
                   activation=LeakyReLU())(main_input)
    else:
        x = Conv2D(nb_filters1, (nb_conv1, nb_conv1),
                   padding='valid', data_format='channels_last',
                   activation=convactfn1)(main_input)

    if nb_filters2 > 0:
        if maxpool:
            x = MaxPooling2D(pool_size=(nb_pool, nb_pool))(x)
        x = Conv2D(nb_filters2, (nb_conv2, nb_conv2))(x)
        if convactfn2 == 'leaky':
            x = Conv2D(nb_filters2, (nb_conv2, nb_conv2), activation=LeakyReLU())(x)
        else:
            x = Conv2D(nb_filters2, (nb_conv2, nb_conv2), activation=convactfn2)(x)

    x = Dropout(drop1)(x)
    x = Flatten()(x)
    # x = BatchNormalization()(x)

    # dense layers
    x = Dense(densesize1, activation=actfn1)(x)
    x = Dropout(drop2)(x)

    if densesize2 > 0:
        x = Dense(densesize2, activation=actfn2)(x)
        x = Dropout(drop2)(x)

    # outputs
    em_trk_none = Dense(3, activation='softmax', name='em_trk_none_netout')(x)
    michel = Dense(1, activation='sigmoid', name='michel_netout')(x)

    sgd = SGD(lr=0.01, decay=1e-5, momentum=0.9, nesterov=True)
    model = Model(inputs=[main_input], outputs=[em_trk_none, michel])
    model.compile(optimizer=sgd,
                  loss={'em_trk_none_netout': 'categorical_crossentropy', 'michel_netout': 'mean_squared_error'},
                  loss_weights={'em_trk_none_netout': 0.1, 'michel_netout': 1.})

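# Illustrative only (assumption, not part of the original flow): the compiled
# model returns its outputs in the order passed to Model(...), i.e. an (N, 3)
# EM/track/none softmax followed by an (N, 1) Michel score, so a prediction
# could be unpacked as:
#
#   em_trk_none_p, michel_p = model.predict(
#       np.zeros((1, img_rows, img_cols, 1), dtype=np.float32))
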
####################### read data sets ############################
n_training = count_events(CNN_INPUT_DIR, 'training')
X_train = np.zeros((n_training, PATCH_SIZE_W, PATCH_SIZE_D, 1), dtype=np.float32)
EmTrkNone_train = np.zeros((n_training, 3), dtype=np.int32)
Michel_train = np.zeros((n_training, 1), dtype=np.int32)
print 'Training data size:', n_training, 'events; patch size:', PATCH_SIZE_W, 'x', PATCH_SIZE_D

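# Dataset layout implied by the loops below (the file and directory names here
# are illustrative assumptions): CNN_INPUT_DIR contains subdirectories whose
# names include 'training' or 'testing', and each subdirectory holds paired
# patch and label files, e.g.
#   <CNN_INPUT_DIR>/training_part0/collection_view_x.npy
#   <CNN_INPUT_DIR>/training_part0/collection_view_y.npy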
ntot = 0
subdirs = [f for f in os.listdir(CNN_INPUT_DIR) if 'training' in f]
subdirs.sort()
for dirname in subdirs:
    print 'Reading data in', dirname
    filesX = [f for f in os.listdir(CNN_INPUT_DIR + '/' + dirname) if '_x.npy' in f]
    for fnameX in filesX:
        print '...training data', fnameX
        fnameY = fnameX.replace('_x.npy', '_y.npy')
        dataX = np.load(CNN_INPUT_DIR + '/' + dirname + '/' + fnameX)
        if dataX.dtype != np.dtype('float32'):
            dataX = dataX.astype("float32")
        dataY = np.load(CNN_INPUT_DIR + '/' + dirname + '/' + fnameY)
        n = dataY.shape[0]
        X_train[ntot:ntot+n] = dataX.reshape(n, img_rows, img_cols, 1)
        EmTrkNone_train[ntot:ntot+n] = dataY[:, [0, 1, 3]]  # label columns 0, 1, 3 feed the em_trk_none output
        Michel_train[ntot:ntot+n] = dataY[:, [2]]           # label column 2 feeds the michel output
        ntot += n
print ntot, 'events ready'

n_testing = count_events(CNN_INPUT_DIR, 'testing')
X_test = np.zeros((n_testing, PATCH_SIZE_W, PATCH_SIZE_D, 1), dtype=np.float32)
EmTrkNone_test = np.zeros((n_testing, 3), dtype=np.int32)
Michel_test = np.zeros((n_testing, 1), dtype=np.int32)
print 'Testing data size:', n_testing, 'events'

ntot = 0
subdirs = [f for f in os.listdir(CNN_INPUT_DIR) if 'testing' in f]
subdirs.sort()
for dirname in subdirs:
    print 'Reading data in', dirname
    filesX = [f for f in os.listdir(CNN_INPUT_DIR + '/' + dirname) if '_x.npy' in f]
    for fnameX in filesX:
        print '...testing data', fnameX
        fnameY = fnameX.replace('_x.npy', '_y.npy')
        dataX = np.load(CNN_INPUT_DIR + '/' + dirname + '/' + fnameX)
        if dataX.dtype != np.dtype('float32'):
            dataX = dataX.astype("float32")
        dataY = np.load(CNN_INPUT_DIR + '/' + dirname + '/' + fnameY)
        n = dataY.shape[0]
        X_test[ntot:ntot+n] = dataX.reshape(n, img_rows, img_cols, 1)
        EmTrkNone_test[ntot:ntot+n] = dataY[:, [0, 1, 3]]
        Michel_test[ntot:ntot+n] = dataY[:, [2]]
        ntot += n
print ntot, 'events ready'

dataX = None
dataY = None

print 'Training', X_train.shape, 'testing', X_test.shape

########################## training ###############################
print 'Fit config:', cfg_name
h = model.fit({'main_input': X_train},
              {'em_trk_none_netout': EmTrkNone_train, 'michel_netout': Michel_train},
              validation_data=(
                  {'main_input': X_test},
                  {'em_trk_none_netout': EmTrkNone_test, 'michel_netout': Michel_test}),
              batch_size=batch_size, epochs=nb_epoch, shuffle=True,
              verbose=1)

X_train = None
EmTrkNone_train = None
Michel_train = None

score = model.evaluate({'main_input': X_test},
                       {'em_trk_none_netout': EmTrkNone_test, 'michel_netout': Michel_test},
                       verbose=0)
print 'Test score:', score

X_test = None
EmTrkNone_test = None
Michel_test = None
#####################################################################

print h.history['loss']
print h.history['val_loss']

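# Optional extension (assumption, not part of the original script): the
# per-epoch losses printed above could also be written out with the json
# module imported near the top of this file, e.g.
#
#   history_out = {k: [float(v) for v in vals] for k, vals in h.history.items()}
#   with open(args.output + cfg_name + '_history.json', 'w') as f:
#       json.dump(history_out, f)
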
if save_model(model, args.output + cfg_name):
    print('All done!')
else:
    print('Error: model not saved.')