training.py
import numpy as np
import pickle
import configparser
import ast
import logging
import sys
import re
import os

from sklearn.utils import class_weight
from keras.models import Model, Sequential, load_model
from keras.layers import Input, Dense, Activation, ZeroPadding2D, Dropout, Flatten, BatchNormalization, SeparableConv2D
from keras import regularizers, optimizers
from keras.layers.convolutional import Conv2D, MaxPooling2D, AveragePooling2D
from keras.callbacks import LearningRateScheduler, ReduceLROnPlateau, CSVLogger, ModelCheckpoint, EarlyStopping
from data_generator import DataGenerator
from collections import Counter

sys.path.append("/home/salonsom/cvn_tensorflow/networks")
sys.path.append("/home/salonsom/cvn_tensorflow/callbacks")

import se_resnet, resnet, resnetpa, googlenet, my_model
import my_callbacks

from keras import backend as K
K.set_image_data_format('channels_last')

'''
****************************************
************** PARAMETERS **************
****************************************
'''

logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)

config = configparser.ConfigParser()
config.read('config.ini')

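# Illustrative only: a minimal 'config.ini' sketch with the sections and keys
# read below. All values are placeholders, not the actual experiment settings.
#
#   [random]
#   seed = 7
#   shuffle = True
#
#   [images]
#   path = /path/to/images
#   views = 3
#   planes = 500
#   cells = 500
#   standardize = True
#   interaction_labels = {'0': 0, '1': 1, '2': 2}
#   neutrino_labels = {'0': 0, '1': 1, '2': 2}
#   filtered = False
#
#   [dataset]
#   path = /path/to/dataset
#   partition_prefix = /partition
#   labels_prefix = /labels
#   interaction_types = False
#
# (plus the [log], [model], [train], and [validation] keys read below)
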
# random

np.random.seed(int(config['random']['seed']))
SHUFFLE = ast.literal_eval(config['random']['shuffle'])

# images

IMAGES_PATH = config['images']['path']
VIEWS = int(config['images']['views'])
PLANES = int(config['images']['planes'])
CELLS = int(config['images']['cells'])
STANDARDIZE = ast.literal_eval(config['images']['standardize'])
INTERACTION_LABELS = ast.literal_eval(config['images']['interaction_labels'])
FILTERED = ast.literal_eval(config['images']['filtered'])

INTERACTION_TYPES = ast.literal_eval(config['dataset']['interaction_types'])

if INTERACTION_TYPES:

    # Interaction types (from 0 to 13 (12))

    NEUTRINO_LABELS = []
    N_LABELS = len(Counter(INTERACTION_LABELS.values()))

else:

    # Neutrino types (from 0 to 3)

    NEUTRINO_LABELS = ast.literal_eval(config['images']['neutrino_labels'])
    N_LABELS = len(Counter(NEUTRINO_LABELS.values()))

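# Illustrative note: len(Counter(...)) counts the number of *distinct* class
# indices, since several raw labels may map to the same class, e.g.:
#
#   >>> from collections import Counter
#   >>> len(Counter({'11': 0, '-11': 0, '13': 1}.values()))
#   2
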
# dataset

DATASET_PATH = config['dataset']['path']
PARTITION_PREFIX = config['dataset']['partition_prefix']
LABELS_PREFIX = config['dataset']['labels_prefix']

# log

LOG_PATH = config['log']['path']
LOG_PREFIX = config['log']['prefix']

# model

CHECKPOINT_PATH = config['model']['checkpoint_path']
CHECKPOINT_PREFIX = config['model']['checkpoint_prefix']
CHECKPOINT_SAVE_MANY = ast.literal_eval(config['model']['checkpoint_save_many'])
CHECKPOINT_SAVE_BEST_ONLY = ast.literal_eval(config['model']['checkpoint_save_best_only'])
CHECKPOINT_PERIOD = int(config['model']['checkpoint_period'])
PRINT_SUMMARY = ast.literal_eval(config['model']['print_summary'])

# train

RESUME = ast.literal_eval(config['train']['resume'])
LEARNING_RATE = float(config['train']['lr'])
MOMENTUM = float(config['train']['momentum'])
DECAY = float(config['train']['decay'])
TRAIN_BATCH_SIZE = int(config['train']['batch_size'])
EPOCHS = int(config['train']['epochs'])
EARLY_STOPPING_PATIENCE = int(config['train']['early_stopping_patience'])
WEIGHTED_LOSS_FUNCTION = ast.literal_eval(config['train']['weighted_loss_function'])
CLASS_WEIGHTS_PREFIX = config['train']['class_weights_prefix']

# validation

VALIDATION_FRACTION = float(config['validation']['fraction'])
VALIDATION_BATCH_SIZE = int(config['validation']['batch_size'])

# train params

TRAIN_PARAMS = {'planes': PLANES,
                'cells': CELLS,
                'views': VIEWS,
                'batch_size': TRAIN_BATCH_SIZE,
                'n_labels': N_LABELS,
                'interaction_labels': INTERACTION_LABELS,
                'interaction_types': INTERACTION_TYPES,
                'filtered': FILTERED,
                'neutrino_labels': NEUTRINO_LABELS,
                'images_path': IMAGES_PATH,
                'standardize': STANDARDIZE,
                'shuffle': SHUFFLE}

# validation params

VALIDATION_PARAMS = {'planes': PLANES,
                     'cells': CELLS,
                     'views': VIEWS,
                     'batch_size': VALIDATION_BATCH_SIZE,
                     'n_labels': N_LABELS,
                     'interaction_labels': INTERACTION_LABELS,
                     'interaction_types': INTERACTION_TYPES,
                     'filtered': FILTERED,
                     'neutrino_labels': NEUTRINO_LABELS,
                     'images_path': IMAGES_PATH,
                     'standardize': STANDARDIZE,
                     'shuffle': SHUFFLE}


'''
****************************************
*************** DATASETS ***************
****************************************
'''

partition = {'train': [], 'validation': [], 'test': []}  # Train, validation, and test IDs
labels = {}  # ID : label

# Load datasets (pickle files must be opened in binary mode)

logging.info('Loading datasets from serialized files...')

with open(DATASET_PATH + PARTITION_PREFIX + '.p', 'rb') as partition_file:
    partition = pickle.load(partition_file)

with open(DATASET_PATH + LABELS_PREFIX + '.p', 'rb') as labels_file:
    labels = pickle.load(labels_file)

if WEIGHTED_LOSS_FUNCTION:

    with open(DATASET_PATH + CLASS_WEIGHTS_PREFIX + '.p', 'rb') as class_weights_file:
        class_weights = pickle.load(class_weights_file)

else:

    class_weights = None

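# Illustrative only: the class-weights pickle is assumed to be produced
# offline. A sketch of how such a dict could be built with the (otherwise
# unused) sklearn.utils.class_weight import above:
#
#   y = [labels[ID] for ID in partition['train']]
#   classes = np.unique(y)
#   weights = class_weight.compute_class_weight('balanced', classes, y)
#   class_weights = dict(zip(classes, weights))  # {class index: weight}
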
# Print some dataset statistics

logging.info('Number of training examples: %d', len(partition['train']))
logging.info('Number of validation examples: %d', len(partition['validation']))
logging.info('Number of test examples: %d', len(partition['test']))
logging.info('Class weights: %s', class_weights)

'''
****************************************
************** GENERATORS **************
****************************************
'''

training_generator = DataGenerator(**TRAIN_PARAMS).generate(labels, partition['train'], True)
validation_generator = DataGenerator(**VALIDATION_PARAMS).generate(labels, partition['validation'], True)
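
# Illustrative only: fit_generator expects an endless iterator of
# (inputs, targets) batches. DataGenerator lives in the project-specific
# data_generator.py; from its usage here it is assumed to look roughly like:
#
#   class DataGenerator:
#       def __init__(self, planes, cells, views, batch_size, ...): ...
#       def generate(self, labels, list_IDs, yield_labels):
#           while True:                # loop forever over the IDs
#               for batch_ids in ...:  # batch_size IDs at a time
#                   X = ...            # (batch_size, planes, cells, views)
#                   y = ...            # one-hot, (batch_size, n_labels)
#                   yield X, y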


'''
****************************************
*************** CVN MODEL **************
****************************************
'''

if RESUME:

    # Resume a previous training

    logging.info('Loading model from disk...')

    if CHECKPOINT_SAVE_MANY:

        # Load the last generated model (reverse-lexicographic sort works
        # because epoch numbers in the checkpoint filenames are zero-padded)

        files = [f for f in os.listdir(CHECKPOINT_PATH) if os.path.isfile(os.path.join(CHECKPOINT_PATH, f))]
        files.sort(reverse=True)

        # CHECKPOINT_PREFIX[1:] drops the leading '/' of the prefix
        r = re.compile(CHECKPOINT_PREFIX[1:] + r'-.*-.*\.h5')

        for fil in files:
            if r.match(fil) is not None:
                model = load_model(CHECKPOINT_PATH + '/' + fil)

                logging.info('Loaded model: %s', CHECKPOINT_PATH + '/' + fil)

                break

    else:

        # Load the single checkpointed model

        model = load_model(CHECKPOINT_PATH + CHECKPOINT_PREFIX + '.h5')

        logging.info('Loaded model: %s', CHECKPOINT_PATH + CHECKPOINT_PREFIX + '.h5')

else:

    # Start a new training

    logging.info('Creating model...')

    # Input shape: (PLANES x CELLS x VIEWS)

    input_shape = [PLANES, CELLS, VIEWS]

    #model = se_resnet.SEResNet(input_shape=input_shape, classes=N_LABELS)
    #model = resnetpa.ResNetPreAct()

    model = resnet.ResnetBuilder.build_resnet_18(input_shape, N_LABELS)
    #model = resnet.ResnetBuilder.build_resnet_34(input_shape, N_LABELS)
    #model = resnet.ResnetBuilder.build_resnet_50(input_shape, N_LABELS)
    #model = resnet.ResnetBuilder.build_resnet_101(input_shape, N_LABELS)
    #model = resnet.ResnetBuilder.build_resnet_152(input_shape, N_LABELS)
    #model = my_model.my_model(input_shape=input_shape, classes=N_LABELS)

    # Optimizer: stochastic gradient descent with Nesterov momentum

    logging.info('Setting optimizer...')

    opt = optimizers.SGD(lr=LEARNING_RATE, momentum=MOMENTUM, decay=DECAY, nesterov=True)
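
    # Illustrative note (assuming Keras 2 SGD semantics): 'decay' applies
    # per-update learning-rate decay, lr_t = lr / (1 + decay * iterations),
    # and nesterov=True switches the momentum update to Nesterov accelerated
    # gradient.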

    # Compile model

    logging.info('Compiling model...')

    model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])

    '''
    model.compile(loss='categorical_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])
    '''

# Print model summary

if PRINT_SUMMARY:
    model.summary()


'''
****************************************
*************** CALLBACKS **************
****************************************
'''

# Checkpointing

logging.info('Configuring checkpointing...')

# Default: checkpoint a single CVN model only

filepath = CHECKPOINT_PATH + CHECKPOINT_PREFIX + '.h5'

if VALIDATION_FRACTION > 0:

    # Monitor validation accuracy

    if CHECKPOINT_SAVE_MANY:
        filepath = CHECKPOINT_PATH + CHECKPOINT_PREFIX + '-{epoch:02d}-{val_acc:.2f}.h5'

    monitor_acc = 'val_acc'
    monitor_loss = 'val_loss'

else:

    # Monitor training accuracy

    if CHECKPOINT_SAVE_MANY:
        filepath = CHECKPOINT_PATH + CHECKPOINT_PREFIX + '-{epoch:02d}-{acc:.2f}.h5'

    monitor_acc = 'acc'
    monitor_loss = 'loss'

checkpoint = ModelCheckpoint(filepath, monitor=monitor_acc, verbose=1, save_best_only=CHECKPOINT_SAVE_BEST_ONLY, mode='max', period=CHECKPOINT_PERIOD)
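
# Illustrative only: with a hypothetical CHECKPOINT_PREFIX of '/model', the
# template above produces files such as 'model-07-0.83.h5' (epoch 7,
# monitored accuracy 0.83).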

# Learning rate reducer

logging.info('Configuring learning rate reducer...')

#lr_reducer = LearningRateScheduler(schedule=lambda epoch, lr: (lr * 0.01 if epoch % 2 == 0 else lr))
lr_reducer = ReduceLROnPlateau(monitor=monitor_loss, factor=0.1, cooldown=0, patience=3, min_lr=0.5e-6, verbose=1)

# Early stopping

logging.info('Configuring early stopping...')

early_stopping = EarlyStopping(monitor=monitor_acc, patience=EARLY_STOPPING_PATIENCE, mode='auto')

# Configuring log file

csv_logger = CSVLogger(LOG_PATH + LOG_PREFIX + '.log', append=RESUME)

# My callbacks

my_callback = my_callbacks.MyCallback()

# Callbacks

logging.info('Setting callbacks...')

callbacks_list = [lr_reducer, checkpoint, early_stopping, csv_logger]
#callbacks_list = [lr_reducer, checkpoint, early_stopping, csv_logger, my_callback]


'''
****************************************
*************** TRAINING ***************
****************************************
'''

if RESUME:

    # Resuming training...

    try:

        # Open the previous log file in order to get the last epoch

        with open(LOG_PATH + LOG_PREFIX + '.log', 'r') as logfile:

            # initial_epoch = last_epoch + 1

            initial_epoch = int(re.search(r'\d+', logfile.read().split('\n')[-2]).group()) + 1

    except IOError:

        # Previous log file does not exist. Set initial epoch to 0

        initial_epoch = 0
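
    # Illustrative note: CSVLogger writes the epoch number as the first CSV
    # column, so the last data row looks like '6,0.8123,0.4567,...'; the
    # re.search(r'\d+') above recovers that 6 and training resumes at epoch 7.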

    logging.info('RESUMING TRAINING...')

else:

    # Starting a new training...
    # initial_epoch must be 0 when starting a training (not resuming one)

    initial_epoch = 0

    logging.info('STARTING TRAINING...')

if VALIDATION_FRACTION > 0:

    # Training with validation

    model.fit_generator(generator=training_generator,
                        steps_per_epoch=len(partition['train']) // TRAIN_BATCH_SIZE,
                        validation_data=validation_generator,
                        validation_steps=len(partition['validation']) // VALIDATION_BATCH_SIZE,
                        epochs=EPOCHS,
                        class_weight=class_weights,
                        callbacks=callbacks_list,
                        initial_epoch=initial_epoch,
                        verbose=1)

else:

    # Training without validation

    model.fit_generator(generator=training_generator,
                        steps_per_epoch=len(partition['train']) // TRAIN_BATCH_SIZE,
                        epochs=EPOCHS,
                        class_weight=class_weights,
                        callbacks=callbacks_list,
                        initial_epoch=initial_epoch,
                        verbose=1)

396 