"""
Train module: builds a CVN model and trains it on the dataset
described in config/config.ini.
"""

__version__ = '1.0'
__author__ = 'Saul Alonso-Monsalve'
__email__ = "saul.alonso.monsalve@cern.ch"

import tensorflow as tf
import numpy as np
import pickle
import configparser
import ast
import logging
import os
import sys
import re
import time

# manually specify the GPUs to use
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

# make the local 'modules' and 'networks' directories importable
sys.path.append(os.path.join(sys.path[0], 'modules'))
sys.path.append(os.path.join(sys.path[0], 'networks'))

from keras import backend as K, regularizers, optimizers
from keras.utils import multi_gpu_model
from keras.models import Model, Sequential, load_model
from keras.layers import Input, Dense, Activation, ZeroPadding2D, Dropout, Flatten, BatchNormalization, SeparableConv2D
from keras.callbacks import LearningRateScheduler, ReduceLROnPlateau, CSVLogger, ModelCheckpoint, EarlyStopping
from keras.layers.convolutional import Conv2D, MaxPooling2D, AveragePooling2D
from keras.regularizers import l2
from collections import Counter
from sklearn.utils import class_weight
from data_generator import DataGenerator
import networks
import my_losses
import my_callbacks

# create a TF1-style session and register it with Keras
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
K.set_session(sess)
K.set_image_data_format('channels_last')  # tensors are (batch, planes, cells, channels)

'''
****************************************
************** PARAMETERS **************
****************************************
'''

logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)

config = configparser.ConfigParser()
config.read('config/config.ini')

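# For reference, a config/config.ini matching the keys read below might look
# like this (illustrative placeholder values, not the project's actual
# configuration):
#
#   [random]
#   seed = -1
#   shuffle = True
#
#   [images]
#   path = /data/images
#   views = 3
#   planes = 500
#   cells = 500
#   standardize = False
#
#   [dataset]
#   path = /data/dataset
#   partition_prefix = /partition
#   labels_prefix = /labels
#
#   [log]
#   path = log
#   prefix = /train
#
#   [model]
#   architecture = resnext
#   checkpoint_path = saved_model
#   checkpoint_prefix = /model
#   checkpoint_save_many = True
#   checkpoint_save_best_only = False
#   checkpoint_period = 1
#   parallelize = True
#   gpus = 4
#   print_summary = False
#   branches = False
#   outputs = 7
#
#   [train]
#   resume = False
#   lr = 0.1
#   momentum = 0.9
#   decay = 0.0001
#   batch_size = 32
#   epochs = 100
#   early_stopping_patience = 20
#   weighted_loss_function = False
#   class_weights_prefix = /class_weights
#   max_queue_size = 10
#
#   [validation]
#   fraction = 0.1
#   batch_size = 32
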
# random

SEED = int(config['random']['seed'])

if SEED == -1:
    SEED = int(time.time())  # random seed

np.random.seed(SEED)
SHUFFLE = ast.literal_eval(config['random']['shuffle'])

# images

IMAGES_PATH = config['images']['path']
VIEWS = int(config['images']['views'])
PLANES = int(config['images']['planes'])
CELLS = int(config['images']['cells'])
STANDARDIZE = ast.literal_eval(config['images']['standardize'])

# dataset

DATASET_PATH = config['dataset']['path']
PARTITION_PREFIX = config['dataset']['partition_prefix']
LABELS_PREFIX = config['dataset']['labels_prefix']

# log

LOG_PATH = config['log']['path']
LOG_PREFIX = config['log']['prefix']

# model

ARCHITECTURE = config['model']['architecture']
CHECKPOINT_PATH = config['model']['checkpoint_path']
CHECKPOINT_PREFIX = config['model']['checkpoint_prefix']
CHECKPOINT_SAVE_MANY = ast.literal_eval(config['model']['checkpoint_save_many'])
CHECKPOINT_SAVE_BEST_ONLY = ast.literal_eval(config['model']['checkpoint_save_best_only'])
CHECKPOINT_PERIOD = int(config['model']['checkpoint_period'])
PARALLELIZE = ast.literal_eval(config['model']['parallelize'])
GPUS = int(config['model']['gpus'])
PRINT_SUMMARY = ast.literal_eval(config['model']['print_summary'])
BRANCHES = ast.literal_eval(config['model']['branches'])
OUTPUTS = int(config['model']['outputs'])

# train

RESUME = ast.literal_eval(config['train']['resume'])
LEARNING_RATE = float(config['train']['lr'])
MOMENTUM = float(config['train']['momentum'])
DECAY = float(config['train']['decay'])
TRAIN_BATCH_SIZE = int(config['train']['batch_size'])
EPOCHS = int(config['train']['epochs'])
EARLY_STOPPING_PATIENCE = int(config['train']['early_stopping_patience'])
WEIGHTED_LOSS_FUNCTION = ast.literal_eval(config['train']['weighted_loss_function'])
CLASS_WEIGHTS_PREFIX = config['train']['class_weights_prefix']
MAX_QUEUE_SIZE = int(config['train']['max_queue_size'])

# validation

VALIDATION_FRACTION = float(config['validation']['fraction'])
VALIDATION_BATCH_SIZE = int(config['validation']['batch_size'])

# train params

TRAIN_PARAMS = {'planes': PLANES,
                'cells': CELLS,
                'views': VIEWS,
                'batch_size': TRAIN_BATCH_SIZE,
                'branches': BRANCHES,
                'outputs': OUTPUTS,
                'images_path': IMAGES_PATH,
                'standardize': STANDARDIZE,
                'shuffle': SHUFFLE}

# validation params

VALIDATION_PARAMS = {'planes': PLANES,
                     'cells': CELLS,
                     'views': VIEWS,
                     'batch_size': VALIDATION_BATCH_SIZE,
                     'branches': BRANCHES,
                     'outputs': OUTPUTS,
                     'images_path': IMAGES_PATH,
                     'standardize': STANDARDIZE,
                     'shuffle': SHUFFLE}


'''
****************************************
*************** DATASETS ***************
****************************************
'''

partition = {'train': [], 'validation': [], 'test': []}  # train, validation, and test IDs
labels = {}  # ID : label

# Load datasets

logging.info('Loading datasets from serialized files...')

with open(DATASET_PATH + PARTITION_PREFIX + '.p', 'rb') as partition_file:
    partition = pickle.load(partition_file)

with open(DATASET_PATH + LABELS_PREFIX + '.p', 'rb') as labels_file:
    labels = pickle.load(labels_file)

if WEIGHTED_LOSS_FUNCTION:
    with open(DATASET_PATH + CLASS_WEIGHTS_PREFIX + '.p', 'rb') as class_weights_file:
        class_weights = pickle.load(class_weights_file)
else:
    class_weights = None

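# class_weight (imported above from sklearn.utils) offers an alternative to
# the serialized weights: balanced weights could be derived on the fly, e.g.
# (a sketch; y_train is a hypothetical flat list of integer labels, and the
# keyword/positional form depends on the scikit-learn version):
#
#   classes = np.unique(y_train)
#   weights = class_weight.compute_class_weight('balanced', classes, y_train)
#   class_weights = dict(zip(classes, weights))
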
# Print some dataset statistics

logging.info('Number of training examples: %d', len(partition['train']))
logging.info('Number of validation examples: %d', len(partition['validation']))
logging.info('Number of test examples: %d', len(partition['test']))
logging.info('Class weights: %s', class_weights)

'''
****************************************
************** GENERATORS **************
****************************************
'''

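# DataGenerator.generate is expected to yield (inputs, targets) batches
# indefinitely, as fit_generator below requires.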
training_generator = DataGenerator(**TRAIN_PARAMS).generate(labels, partition['train'], True)
validation_generator = DataGenerator(**VALIDATION_PARAMS).generate(labels, partition['validation'], True)


'''
****************************************
*************** CVN MODEL **************
****************************************
'''

# Optimizer: Stochastic Gradient Descent

logging.info('Setting optimizer...')

opt = optimizers.SGD(lr=LEARNING_RATE, momentum=MOMENTUM, decay=DECAY, nesterov=True)  # SGD
#opt = optimizers.RMSprop(lr=0.045, rho=0.9, decay=0.94, epsilon=None, clipnorm=2.0)  # RMSprop (rho = decay factor, decay = learning rate decay over each update)
#opt = optimizers.Adam(lr=1e-3)  # Adam
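# Note: Keras' SGD applies time-based decay, lr_t = lr / (1 + decay * iterations),
# so a DECAY > 0 anneals the learning rate over the course of training.
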
if RESUME:
    # Resume a previous training
    logging.info('Loading model from disk...')

    if CHECKPOINT_SAVE_MANY:
        # Load the most recent checkpoint that matches the prefix
        files = [f for f in os.listdir(CHECKPOINT_PATH) if os.path.isfile(os.path.join(CHECKPOINT_PATH, f))]
        files.sort(reverse=True)

        r = re.compile(CHECKPOINT_PREFIX[1:] + r'-.*-.*\.h5')

        for fil in files:
            if r.match(fil) is not None:
                filename = CHECKPOINT_PATH + '/' + fil
                sequential_model = load_model(filename,
                                              custom_objects={'tf': tf,
                                                              'masked_loss': my_losses.masked_loss,
                                                              'multitask_loss': my_losses.multitask_loss,
                                                              'masked_loss_binary': my_losses.masked_loss_binary,
                                                              'masked_loss_categorical': my_losses.masked_loss_categorical})
                logging.info('Loaded model: %s', filename)
                break
    else:
        # Load the single saved model
        filename = CHECKPOINT_PATH + CHECKPOINT_PREFIX + '.h5'
        sequential_model = load_model(filename,
                                      custom_objects={'tf': tf,
                                                      'masked_loss': my_losses.masked_loss,
                                                      'multitask_loss': my_losses.multitask_loss,
                                                      'masked_loss_binary': my_losses.masked_loss_binary,
                                                      'masked_loss_categorical': my_losses.masked_loss_categorical})
        logging.info('Loaded model: %s', filename)
else:
    # Start a new training
    logging.info('Creating model...')

    # Input shape
    if BRANCHES:
        input_shape = [PLANES, CELLS, 1]      # input: VIEWS x [PLANES x CELLS x 1]
    else:
        input_shape = [PLANES, CELLS, VIEWS]  # input: [PLANES x CELLS x VIEWS]

    '''
    Create model. Options (argument 'network'):

                                                       Total params   Trainable params   Non-trainable params

    'xception'            -> Xception                    20,888,117         20,833,589                 54,528
    'vgg16'               -> VGG-16                      503,412,557        503,412,557                      0
    'vgg19'               -> VGG-19                      508,722,253        508,722,253                      0
    'resnet18'            -> ResNet-18                    11,193,997         11,186,189                  7,808
    'resnet34'            -> ResNet-34                    21,313,293         21,298,061                 15,232
    'resnet50'            -> ResNet-50                    23,694,221         23,641,101                 53,120
    'resnet101'           -> ResNet-101                   42,669,453         42,571,789                 97,664
    'resnet152'           -> ResNet-152                   58,382,221         58,238,477                143,744
    'inceptionv3'         -> Inception-v3                 21,829,421         21,794,989                 34,432
    'inceptionv4'         -> Inception-v4                 41,194,381         41,131,213                 63,168
    'inceptionresnetv2'   -> Inception-ResNet-v2          54,356,717         54,296,173                 60,544
    'resnext'             -> ResNeXt                      89,712,576         89,612,096                100,480
    'seresnet18'          -> SE-ResNet-18                 11,276,224         11,268,416                  7,808
    'seresnet34'          -> SE-ResNet-34                 21,461,952         21,446,720                 15,232
    'seresnet50'          -> SE-ResNet-50                 26,087,360         26,041,920                 45,440
    'seresnet101'         -> SE-ResNet-101                47,988,672         47,887,936                100,736
    'seresnet154'         -> SE-ResNet-154                64,884,672         64,740,928                143,744
    'seresnetsaul'        -> SE-ResNet-Saul               22,072,768         22,055,744                 17,024
    'seinceptionv3'       -> SE-Inception-v3              23,480,365         23,445,933                 34,432
    'seinceptionresnetv2' -> SE-Inception-ResNet-v2       64,094,445         64,033,901                 60,544
    'seresnext'           -> SE-ResNeXt                   97,869,632         97,869,632                100,480
    'mobilenet'           -> MobileNet                     3,242,189          3,220,301                 21,888
    'densenet121'         -> DenseNet-121                  7,050,829          6,967,181                 83,648
    'densenet169'         -> DenseNet-169                 12,664,525         12,506,125                158,400
    'densenet201'         -> DenseNet-201                 18,346,957         18,117,901                229,056
    other                 -> Custom model
    '''

    aux_model = networks.create_model(network=ARCHITECTURE, input_shape=input_shape)
    aux_model.layers.pop()  # remove the last (classification) layer of the model

    weight_decay = 1e-4

    x = [None]*OUTPUTS

    if OUTPUTS == 1:
        # single 13-category output
        x[0] = Dense(13, use_bias=False, kernel_regularizer=l2(weight_decay),
                     activation='softmax', name='categories')(aux_model.layers[-1].output)
    elif OUTPUTS == 5:
        # one head per task
        x[0] = Dense(4, use_bias=False, kernel_regularizer=l2(weight_decay),
                     activation='softmax', name='flavour')(aux_model.layers[-1].output)
        x[1] = Dense(4, use_bias=False, kernel_regularizer=l2(weight_decay),
                     activation='softmax', name='protons')(aux_model.layers[-1].output)
        x[2] = Dense(4, use_bias=False, kernel_regularizer=l2(weight_decay),
                     activation='softmax', name='pions')(aux_model.layers[-1].output)
        x[3] = Dense(4, use_bias=False, kernel_regularizer=l2(weight_decay),
                     activation='softmax', name='pizeros')(aux_model.layers[-1].output)
        x[4] = Dense(4, use_bias=False, kernel_regularizer=l2(weight_decay),
                     activation='softmax', name='neutrons')(aux_model.layers[-1].output)
    else:
        # full set of seven heads
        x[0] = Dense(1, use_bias=False, kernel_regularizer=l2(weight_decay),
                     activation='sigmoid', name='is_antineutrino')(aux_model.layers[-1].output)
        x[1] = Dense(4, use_bias=False, kernel_regularizer=l2(weight_decay),
                     activation='softmax', name='flavour')(aux_model.layers[-1].output)
        x[2] = Dense(4, use_bias=False, kernel_regularizer=l2(weight_decay),
                     activation='softmax', name='interaction')(aux_model.layers[-1].output)
        x[3] = Dense(4, use_bias=False, kernel_regularizer=l2(weight_decay),
                     activation='softmax', name='protons')(aux_model.layers[-1].output)
        x[4] = Dense(4, use_bias=False, kernel_regularizer=l2(weight_decay),
                     activation='softmax', name='pions')(aux_model.layers[-1].output)
        x[5] = Dense(4, use_bias=False, kernel_regularizer=l2(weight_decay),
                     activation='softmax', name='pizeros')(aux_model.layers[-1].output)
        x[6] = Dense(4, use_bias=False, kernel_regularizer=l2(weight_decay),
                     activation='softmax', name='neutrons')(aux_model.layers[-1].output)

    sequential_model = Model(inputs=aux_model.inputs, outputs=x, name='resnext')  # note: name is fixed regardless of ARCHITECTURE

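# With multiple named outputs, each batch from the data generator must carry
# one target array per head (a list in output order or a dict keyed by the
# names above) so that Keras can match every loss to its output.
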
if PARALLELIZE:
    # Parallelize the model (use all the available GPUs)
    try:
        model = multi_gpu_model(sequential_model, gpus=GPUS, cpu_relocation=True)
        num_outputs = len(sequential_model.output_names)
        model.layers[-(num_outputs+1)].set_weights(sequential_model.get_weights())  # copy the weights of the sequential model
        logging.info('Training using %d GPUs...', GPUS)
    except Exception:
        # fall back to a single device if multi-GPU wrapping fails
        model = sequential_model
        logging.info('Training using single GPU or CPU...')
else:
    model = sequential_model

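# The masked losses below ignore heads for which an event carries no valid
# target. A minimal sketch of the idea (an illustration only, not necessarily
# how my_losses implements it), where an all-zero one-hot row marks a missing
# label and contributes zero loss:
#
#   def masked_loss_categorical_sketch(y_true, y_pred):
#       mask = K.cast(K.sum(y_true, axis=-1) > 0, K.floatx())
#       return K.categorical_crossentropy(y_true, y_pred) * mask
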
if not model._is_compiled:
    # Compile model
    logging.info('Compiling model...')

    if OUTPUTS == 1:
        model_loss = {'categories': my_losses.masked_loss_categorical}
    elif OUTPUTS == 5:
        model_loss = {'flavour': my_losses.masked_loss_categorical,
                      'protons': my_losses.masked_loss_categorical,
                      'pions': my_losses.masked_loss_categorical,
                      'pizeros': my_losses.masked_loss_categorical,
                      'neutrons': my_losses.masked_loss_categorical}
    else:
        model_loss = {'is_antineutrino': my_losses.masked_loss_binary,
                      'flavour': my_losses.masked_loss_categorical,
                      'interaction': my_losses.masked_loss_categorical,
                      'protons': my_losses.masked_loss_categorical,
                      'pions': my_losses.masked_loss_categorical,
                      'pizeros': my_losses.masked_loss_categorical,
                      'neutrons': my_losses.masked_loss_categorical}

    # alternatives kept for reference:
    #model.compile(loss=my_losses.masked_loss, optimizer=opt, metrics=['accuracy'])
    #model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])
    #loss={'neutrino': my_losses.masked_loss_binary, 'flavour': my_losses.masked_loss_categorical, 'interaction': my_losses.masked_loss_categorical}
    #loss_weights={'neutrino': 0.25, 'flavour': 1.0, 'interaction': 0.5}
    #loss_weights={'neutrino': 0.33, 'flavour': 0.33, 'interaction': 0.33}
    model.compile(loss=model_loss,
                  optimizer=opt,
                  metrics=['accuracy'])

if PRINT_SUMMARY:
    # Print model summary
    if PARALLELIZE:
        sequential_model.summary()
    model.summary()

'''
****************************************
*************** CALLBACKS **************
****************************************
'''

# Checkpointing

logging.info('Configuring checkpointing...')

# Checkpoint one CVN model only

filepath = CHECKPOINT_PATH + CHECKPOINT_PREFIX + '.h5'

if VALIDATION_FRACTION > 0:
    if OUTPUTS == 1:
        # Validation accuracy
        if CHECKPOINT_SAVE_MANY:
            filepath = CHECKPOINT_PATH + CHECKPOINT_PREFIX + '-{epoch:02d}-{val_acc:.2f}.h5'
        monitor_acc = 'val_acc'
        monitor_loss = 'val_loss'
    else:
        # Validation accuracy of the flavour output
        if CHECKPOINT_SAVE_MANY:
            filepath = CHECKPOINT_PATH + CHECKPOINT_PREFIX + '-{epoch:02d}-{val_flavour_acc:.2f}.h5'
        monitor_acc = 'val_flavour_acc'
        monitor_loss = 'val_flavour_loss'
else:
    # Training accuracy
    if CHECKPOINT_SAVE_MANY:
        filepath = CHECKPOINT_PATH + CHECKPOINT_PREFIX + '-{epoch:02d}-{acc:.2f}.h5'
    monitor_acc = 'acc'
    monitor_loss = 'loss'

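# ModelCheckpoint (and, presumably, the detached variant from my_callbacks)
# fills the {epoch:...} and {val_...:...} placeholders in filepath with the
# values logged at the end of each epoch.
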
if PARALLELIZE:
    # checkpoint the underlying template model rather than the multi-GPU wrapper
    #checkpoint = my_callbacks.MultiGPUCheckpointCallback(filepath, sequential_model, monitor=monitor_acc, verbose=1, save_best_only=CHECKPOINT_SAVE_BEST_ONLY, mode='max', period=CHECKPOINT_PERIOD)
    checkpoint = my_callbacks.ModelCheckpointDetached(filepath, monitor=monitor_acc, verbose=1, save_best_only=CHECKPOINT_SAVE_BEST_ONLY, save_weights_only=False, mode='max', period=CHECKPOINT_PERIOD)
else:
    checkpoint = ModelCheckpoint(filepath, monitor=monitor_acc, verbose=1, save_best_only=CHECKPOINT_SAVE_BEST_ONLY, mode='max', period=CHECKPOINT_PERIOD)

# Learning rate reducer

logging.info('Configuring learning rate reducer...')

#lr_reducer = LearningRateScheduler(schedule=lambda epoch, lr: (lr*0.01 if epoch % 2 == 0 else lr))
#lr_reducer = ReduceLROnPlateau(monitor=monitor_loss, factor=0.1, cooldown=0, patience=2, min_lr=0.5e-6, verbose=1)
lr_reducer = ReduceLROnPlateau(monitor=monitor_acc, mode='max', factor=0.1, cooldown=0, patience=10, min_lr=0.5e-6, verbose=1)

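# ReduceLROnPlateau multiplies the learning rate by `factor` whenever the
# monitored quantity has not improved for `patience` epochs, down to min_lr.
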
# Early stopping

logging.info('Configuring early stopping...')

early_stopping = EarlyStopping(monitor=monitor_acc, patience=EARLY_STOPPING_PATIENCE, mode='auto')

# Configure the log file

csv_logger = CSVLogger(LOG_PATH + LOG_PREFIX + '.log', append=RESUME)

# My callbacks

my_callback = my_callbacks.MyCallback()
#my_callback = my_callbacks.InceptionV4Callback()
#my_callback = my_callbacks.IterationsCallback(validation_generator=validation_generator, validation_steps=len(partition['validation'])//VALIDATION_BATCH_SIZE)

# Callbacks

logging.info('Setting callbacks...')

#callbacks_list = [checkpoint, csv_logger]
#callbacks_list = [checkpoint, csv_logger, my_callback]
#callbacks_list = [lr_reducer, checkpoint, early_stopping, csv_logger]
callbacks_list = [lr_reducer, checkpoint, early_stopping, csv_logger, my_callback]


'''
****************************************
*************** TRAINING ***************
****************************************
'''

if RESUME:
    # Resuming training: recover the last epoch from the previous log file
    try:
        with open(LOG_PATH + LOG_PREFIX + '.log', 'r') as logfile:
            # each CSVLogger row starts with the epoch number, so the first
            # integer of the last data row is the last completed epoch;
            # initial_epoch = last_epoch + 1
            initial_epoch = int(re.search(r'\d+', logfile.read().split('\n')[-2]).group()) + 1
    except IOError:
        # previous log file does not exist: start from epoch 0
        initial_epoch = 0
    logging.info('RESUMING TRAINING...')
else:
    # Starting a new training: initial_epoch must be 0
    initial_epoch = 0
    logging.info('STARTING TRAINING...')

if VALIDATION_FRACTION > 0:
    # Training with validation
    validation_data = validation_generator
    validation_steps = len(partition['validation'])//VALIDATION_BATCH_SIZE
else:
    # Training without validation
    validation_data = None
    validation_steps = None

# TRAINING

model.fit_generator(generator=training_generator,
                    steps_per_epoch=len(partition['train'])//TRAIN_BATCH_SIZE,
                    validation_data=validation_data,
                    validation_steps=validation_steps,
                    epochs=EPOCHS,
                    class_weight=class_weights,
                    callbacks=callbacks_list,
                    initial_epoch=initial_epoch,
                    max_queue_size=MAX_QUEUE_SIZE,
                    verbose=1,
                    use_multiprocessing=False,
                    workers=1)

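# No explicit model.save() is needed here: the checkpoint callback above
# persists the model to CHECKPOINT_PATH during training.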