"""
This is the train module.
"""

__author__ = 'Saul Alonso-Monsalve'
__email__ = "saul.alonso.monsalve@cern.ch"

import tensorflow as tf

# Imports used below (the original import block is only partially present in
# this extract, so the rest is reconstructed from usage).
import numpy as np
import ast
import configparser
import logging
import os
import pickle
import re
import sys
import time
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

sys.path.append(os.path.join(sys.path[0], 'modules'))
sys.path.append(os.path.join(sys.path[0], 'networks'))
from keras import backend as K, regularizers, optimizers
from keras.utils import multi_gpu_model
from keras.models import Model, Sequential, load_model
from keras.layers import Input, Dense, Activation, ZeroPadding2D, Dropout, Flatten, BatchNormalization, SeparableConv2D
from keras.callbacks import LearningRateScheduler, ReduceLROnPlateau, CSVLogger, ModelCheckpoint, EarlyStopping
from keras.layers.convolutional import Conv2D, MaxPooling2D, AveragePooling2D
from keras.regularizers import l2
from collections import Counter
from sklearn.utils import class_weight
from data_generator import DataGenerator

# Project-local modules referenced below (their import lines are missing from
# this extract, so they are reconstructed here).
import my_losses
import my_callbacks
init = tf.global_variables_initializer()

K.set_image_data_format('channels_last')
'''
****************************************
************** PARAMETERS **************
****************************************
'''

logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
config = configparser.ConfigParser()
config.read('config/config.ini')
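# For reference, a sketch of the sections and keys 'config/config.ini' must
# provide, inferred from the reads below. The values shown are illustrative
# assumptions, not the project's defaults:
#
#   [random]
#   seed = -1
#   shuffle = True
#
#   [images]
#   path = /path/to/images
#   views = 3
#   planes = 500
#   cells = 500
#   standardize = False
#
#   [train]
#   resume = False
#   lr = 0.1
#   momentum = 0.9
#   decay = 0.0001
#   batch_size = 32
#   epochs = 100
#   early_stopping_patience = 20
#   weighted_loss_function = False
#   class_weights_prefix = class_weights
#   max_queue_size = 10
#
# plus [dataset], [log], [model], and [validation] sections with the keys
# read below.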
SEED = int(config['random']['seed'])

if SEED == -1:   # reconstructed guard (assumption): a sentinel seed falls back to the current time
    SEED = int(time.time())

SHUFFLE = ast.literal_eval(config['random']['shuffle'])
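# SEED is not visibly used elsewhere in this extract; presumably it seeds the
# global NumPy RNG so that shuffling is reproducible (assumption):
np.random.seed(SEED)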
IMAGES_PATH = config['images']['path']
VIEWS = int(config['images']['views'])
PLANES = int(config['images']['planes'])
CELLS = int(config['images']['cells'])
STANDARDIZE = ast.literal_eval(config['images']['standardize'])
DATASET_PATH = config['dataset']['path']
PARTITION_PREFIX = config['dataset']['partition_prefix']
LABELS_PREFIX = config['dataset']['labels_prefix']
LOG_PATH = config['log']['path']
LOG_PREFIX = config['log']['prefix']
ARCHITECTURE = config['model']['architecture']
CHECKPOINT_PATH = config['model']['checkpoint_path']
CHECKPOINT_PREFIX = config['model']['checkpoint_prefix']
CHECKPOINT_SAVE_MANY = ast.literal_eval(config['model']['checkpoint_save_many'])
CHECKPOINT_SAVE_BEST_ONLY = ast.literal_eval(config['model']['checkpoint_save_best_only'])
CHECKPOINT_PERIOD = int(config['model']['checkpoint_period'])
PARALLELIZE = ast.literal_eval(config['model']['parallelize'])
GPUS = int(config['model']['gpus'])
PRINT_SUMMARY = ast.literal_eval(config['model']['print_summary'])
BRANCHES = ast.literal_eval(config['model']['branches'])
OUTPUTS = int(config['model']['outputs'])
RESUME = ast.literal_eval(config['train']['resume'])
LEARNING_RATE = float(config['train']['lr'])
MOMENTUM = float(config['train']['momentum'])
DECAY = float(config['train']['decay'])
TRAIN_BATCH_SIZE = int(config['train']['batch_size'])
EPOCHS = int(config['train']['epochs'])
EARLY_STOPPING_PATIENCE = int(config['train']['early_stopping_patience'])
WEIGHTED_LOSS_FUNCTION = ast.literal_eval(config['train']['weighted_loss_function'])
CLASS_WEIGHTS_PREFIX = config['train']['class_weights_prefix']
MAX_QUEUE_SIZE = int(config['train']['max_queue_size'])
VALIDATION_FRACTION = float(config['validation']['fraction'])
VALIDATION_BATCH_SIZE = int(config['validation']['batch_size'])
# Arguments for the data generators. Some entries are missing from this
# extract; the gaps are filled with the matching config values defined above
# (assumption).
TRAIN_PARAMS = {'planes':PLANES,
                'cells':CELLS,
                'views':VIEWS,
                'batch_size':TRAIN_BATCH_SIZE,
                'branches':BRANCHES,
                'outputs':OUTPUTS,
                'images_path':IMAGES_PATH,
                'standardize':STANDARDIZE,
                'shuffle':SHUFFLE}

VALIDATION_PARAMS = {'planes':PLANES,
                     'cells':CELLS,
                     'views':VIEWS,
                     'batch_size':VALIDATION_BATCH_SIZE,
                     'branches':BRANCHES,
                     'outputs':OUTPUTS,
                     'images_path':IMAGES_PATH,
                     'standardize':STANDARDIZE,
                     'shuffle':SHUFFLE}
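# How the two less obvious flags shape the run (see their use below): BRANCHES
# selects one single-channel input branch per view (input shape
# [PLANES, CELLS, 1]) instead of a single input with the views stacked as
# channels ([PLANES, CELLS, VIEWS]); OUTPUTS selects the head layout of the
# network: 1 (a single 13-way 'categories' head), 5 (flavour plus four
# particle-count heads), or 7 (additionally 'is_antineutrino' and
# 'interaction').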
'''
****************************************
*************** DATASETS ***************
****************************************
'''

partition = {'train': [], 'validation': [], 'test': []}
class_weights = None   # stays None unless a weighted loss function is requested
logging.info('Loading datasets from serialized files...')

# Pickled files must be opened in binary mode ('rb'; the original used 'r',
# which fails under Python 3).
with open(DATASET_PATH + PARTITION_PREFIX + '.p', 'rb') as partition_file:
    partition = pickle.load(partition_file)

with open(DATASET_PATH + LABELS_PREFIX + '.p', 'rb') as labels_file:
    labels = pickle.load(labels_file)

if WEIGHTED_LOSS_FUNCTION:
    with open(DATASET_PATH + CLASS_WEIGHTS_PREFIX + '.p', 'rb') as class_weights_file:
        class_weights = pickle.load(class_weights_file)

logging.info('Number of training examples: %d', len(partition['train']))
logging.info('Number of validation examples: %d', len(partition['validation']))
logging.info('Number of test examples: %d', len(partition['test']))
logging.info('Class weights: %s', class_weights)
'''
****************************************
************** GENERATORS **************
****************************************
'''

training_generator = DataGenerator(**TRAIN_PARAMS).generate(labels, partition['train'], True)
validation_generator = DataGenerator(**VALIDATION_PARAMS).generate(labels, partition['validation'], True)
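# fit_generator consumes an endless iterator of batches. A minimal sketch of
# the contract DataGenerator.generate has to satisfy (hypothetical names; the
# real implementation lives in the data_generator module):
#
#   def generate(self, labels, IDs, yield_labels):
#       while True:                                   # loop forever over epochs
#           for batch_ids in self._iterate_batches(IDs):
#               X = self._load_images(batch_ids)      # (batch, planes, cells, channels)
#               y = [labels[i] for i in batch_ids]
#               yield X, y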
'''
****************************************
*************** CVN MODEL **************
****************************************
'''

logging.info('Setting optimizer...')

# Nesterov SGD; Keras applies time-based decay, lr_t = lr / (1 + decay * iterations).
opt = optimizers.SGD(lr=LEARNING_RATE, momentum=MOMENTUM, decay=DECAY, nesterov=True)
if RESUME:
    # Resume a previous training session.
    logging.info('Loading model from disk...')

    if CHECKPOINT_SAVE_MANY:
        # Several checkpoints are kept: pick the most recent file whose name
        # matches '<prefix>-<epoch>-<accuracy>.h5'.
        files = [f for f in os.listdir(CHECKPOINT_PATH) if os.path.isfile(os.path.join(CHECKPOINT_PATH, f))]
        files.sort(reverse=True)

        r = re.compile(CHECKPOINT_PREFIX[1:] + '-.*-.*.h5')

        for fil in files:
            if r.match(fil) is not None:
                filename = CHECKPOINT_PATH + '/' + fil
                sequential_model = load_model(filename,
                                              custom_objects={'tf':tf,
                                                              'masked_loss':my_losses.masked_loss,
                                                              'multitask_loss':my_losses.multitask_loss,
                                                              'masked_loss_binary':my_losses.masked_loss_binary,
                                                              'masked_loss_categorical':my_losses.masked_loss_categorical})
                logging.info('Loaded model: %s', CHECKPOINT_PATH + '/' + fil)
                break   # reconstructed: stop at the first (newest) match
    else:
        # A single checkpoint file is kept.
        filename = CHECKPOINT_PATH + CHECKPOINT_PREFIX + '.h5'
        # custom_objects reconstructed to mirror the branch above (this spot
        # was garbled in the extract).
        sequential_model = load_model(filename,
                                      custom_objects={'tf':tf,
                                                      'masked_loss':my_losses.masked_loss,
                                                      'multitask_loss':my_losses.multitask_loss,
                                                      'masked_loss_binary':my_losses.masked_loss_binary,
                                                      'masked_loss_categorical':my_losses.masked_loss_categorical})
        logging.info('Loaded model: %s', CHECKPOINT_PATH + CHECKPOINT_PREFIX + '.h5')
else:
    # Start a new training session.
    logging.info('Creating model...')

    if BRANCHES:   # reconstructed guard
        input_shape = [PLANES, CELLS, 1]       # one single-view branch
    else:
        input_shape = [PLANES, CELLS, VIEWS]   # views stacked as channels

    '''
    Create model. Options (argument 'network'):

                                                     Total params  Trainable params  Non-trainable params
    'xception'             -> Xception                 20,888,117        20,833,589                54,528
    'vgg16'                -> VGG-16                  503,412,557       503,412,557                     0
    'vgg19'                -> VGG-19                  508,722,253       508,722,253                     0
    'resnet18'             -> ResNet-18                11,193,997        11,186,189                 7,808
    'resnet34'             -> ResNet-34                21,313,293        21,298,061                15,232
    'resnet50'             -> ResNet-50                23,694,221        23,641,101                53,120
    'resnet101'            -> ResNet-101               42,669,453        42,571,789                97,664
    'resnet152'            -> ResNet-152               58,382,221        58,238,477               143,744
    'inceptionv3'          -> Inception-v3             21,829,421        21,794,989                34,432
    'inceptionv4'          -> Inception-v4             41,194,381        41,131,213                63,168
    'inceptionresnetv2'    -> Inception-ResNet-v2      54,356,717        54,296,173                60,544
    'resnext'              -> ResNeXt                  89,712,576        89,612,096               100,480
    'seresnet18'           -> SE-ResNet-18             11,276,224        11,268,416                 7,808
    'seresnet34'           -> SE-ResNet-34             21,461,952        21,446,720                15,232
    'seresnet50'           -> SE-ResNet-50             26,087,360        26,041,920                45,440
    'seresnet101'          -> SE-ResNet-101            47,988,672        47,887,936               100,736
    'seresnet154'          -> SE-ResNet-154            64,884,672        64,740,928               143,744
    'seresnetsaul'         -> SE-ResNet-Saul           22,072,768        22,055,744                17,024
    'seinceptionv3'        -> SE-Inception-v3          23,480,365        23,445,933                34,432
    'seinceptionresnetv2'  -> SE-Inception-ResNet-v2   64,094,445        64,033,901                60,544
    'seresnext'            -> SE-ResNeXt               97,869,632        97,869,632               100,480
    'mobilenet'            -> MobileNet                 3,242,189         3,220,301                21,888
    'densenet121'          -> DenseNet-121              7,050,829         6,967,181                83,648
    'densenet169'          -> DenseNet-169             12,664,525        12,506,125               158,400
    'densenet201'          -> DenseNet-201             18,346,957        18,117,901               229,056
    other                  -> Custom model
    '''

    # The lines that build the base network are missing from this extract; the
    # two lines below are an assumption based on the 'networks' directory on
    # sys.path and the create_model signature referenced elsewhere.
    from networks import create_model                        # hypothetical import
    aux_model = create_model(network=ARCHITECTURE, input_shape=input_shape)

    weight_decay = 1e-4      # assumption: the original value is not in this extract

    aux_model.layers.pop()   # drop the base network's final layer so task heads can be attached

    x = [None] * OUTPUTS     # reconstructed: container for the output heads
    if OUTPUTS == 1:
        # Single head: 13-way event classification (guard reconstructed).
        x[0] = Dense(13, use_bias=False, kernel_regularizer=l2(weight_decay),
                     activation='softmax', name='categories')(aux_model.layers[-1].output)
    elif OUTPUTS == 5:
        # Five heads: flavour plus per-particle counts (guard reconstructed).
        x[0] = Dense(4, use_bias=False, kernel_regularizer=l2(weight_decay),
                     activation='softmax', name='flavour')(aux_model.layers[-1].output)
        x[1] = Dense(4, use_bias=False, kernel_regularizer=l2(weight_decay),
                     activation='softmax', name='protons')(aux_model.layers[-1].output)
        x[2] = Dense(4, use_bias=False, kernel_regularizer=l2(weight_decay),
                     activation='softmax', name='pions')(aux_model.layers[-1].output)
        x[3] = Dense(4, use_bias=False, kernel_regularizer=l2(weight_decay),
                     activation='softmax', name='pizeros')(aux_model.layers[-1].output)
        x[4] = Dense(4, use_bias=False, kernel_regularizer=l2(weight_decay),
                     activation='softmax', name='neutrons')(aux_model.layers[-1].output)
    else:
        # Seven heads: antineutrino flag, flavour, interaction type and counts.
        x[0] = Dense(1, use_bias=False, kernel_regularizer=l2(weight_decay),
                     activation='sigmoid', name='is_antineutrino')(aux_model.layers[-1].output)
        x[1] = Dense(4, use_bias=False, kernel_regularizer=l2(weight_decay),
                     activation='softmax', name='flavour')(aux_model.layers[-1].output)
        x[2] = Dense(4, use_bias=False, kernel_regularizer=l2(weight_decay),
                     activation='softmax', name='interaction')(aux_model.layers[-1].output)
        x[3] = Dense(4, use_bias=False, kernel_regularizer=l2(weight_decay),
                     activation='softmax', name='protons')(aux_model.layers[-1].output)
        x[4] = Dense(4, use_bias=False, kernel_regularizer=l2(weight_decay),
                     activation='softmax', name='pions')(aux_model.layers[-1].output)
        x[5] = Dense(4, use_bias=False, kernel_regularizer=l2(weight_decay),
                     activation='softmax', name='pizeros')(aux_model.layers[-1].output)
        x[6] = Dense(4, use_bias=False, kernel_regularizer=l2(weight_decay),
                     activation='softmax', name='neutrons')(aux_model.layers[-1].output)

    sequential_model = Model(inputs=aux_model.inputs, outputs=x, name='resnext')
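# All heads above are Dense layers attached to the same trunk output, so the
# convolutional features are shared and the tasks are trained jointly. Keras
# matches each named output ('flavour', 'protons', ...) to the entry with the
# same key in the loss dictionary passed to model.compile() below.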
if PARALLELIZE:
    try:
        # Replicate the model across GPUs (try/except reconstructed: the
        # extract shows both the multi-GPU path and a single-device fallback).
        model = multi_gpu_model(sequential_model, gpus=GPUS, cpu_relocation=True)
        # Copy the current weights into the template model inside the wrapper.
        num_outputs = len(sequential_model.output_names)
        model.layers[-(num_outputs+1)].set_weights(sequential_model.get_weights())
        logging.info('Training using %d GPUs...', GPUS)
    except ValueError:
        model = sequential_model
        logging.info('Training using single GPU or CPU...')
else:
    model = sequential_model
if not model._is_compiled:
    logging.info('Compiling model...')

    # Guards reconstructed to mirror the head definitions above.
    if OUTPUTS == 1:
        model_loss = {'categories':my_losses.masked_loss_categorical}
    elif OUTPUTS == 5:
        model_loss = {'flavour':my_losses.masked_loss_categorical,
                      'protons':my_losses.masked_loss_categorical,
                      'pions':my_losses.masked_loss_categorical,
                      'pizeros':my_losses.masked_loss_categorical,
                      'neutrons':my_losses.masked_loss_categorical}
    else:
        model_loss = {'is_antineutrino':my_losses.masked_loss_binary,
                      'flavour':my_losses.masked_loss_categorical,
                      'interaction':my_losses.masked_loss_categorical,
                      'protons':my_losses.masked_loss_categorical,
                      'pions':my_losses.masked_loss_categorical,
                      'pizeros':my_losses.masked_loss_categorical,
                      'neutrons':my_losses.masked_loss_categorical}

    # The compile call is partially missing from this extract; loss and
    # optimizer are the objects defined above.
    model.compile(loss=model_loss,
                  optimizer=opt,
                  metrics=['accuracy'])
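# The masked losses come from the project's my_losses module. A minimal sketch
# of the idea (assuming a sentinel label, e.g. -1, marks events where a head
# does not apply, so those events contribute zero loss to that head):
#
#   def masked_loss_categorical(y_true, y_pred):
#       mask = K.cast(K.any(K.not_equal(y_true, -1), axis=-1), K.floatx())
#       return K.categorical_crossentropy(y_true, y_pred) * mask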
if PRINT_SUMMARY:   # guard reconstructed from the PRINT_SUMMARY flag
    sequential_model.summary()
'''
****************************************
*************** CALLBACKS **************
****************************************
'''

logging.info('Configuring checkpointing...')
filepath = CHECKPOINT_PATH + CHECKPOINT_PREFIX + '.h5'

if VALIDATION_FRACTION > 0:
    # Monitor validation metrics; with several heads, track the flavour head
    # (the intermediate guard on OUTPUTS is reconstructed).
    if OUTPUTS == 1:
        if CHECKPOINT_SAVE_MANY:
            filepath = CHECKPOINT_PATH + CHECKPOINT_PREFIX + '-{epoch:02d}-{val_acc:.2f}.h5'
        monitor_acc = 'val_acc'
        monitor_loss = 'val_loss'
    else:
        if CHECKPOINT_SAVE_MANY:
            filepath = CHECKPOINT_PATH + CHECKPOINT_PREFIX + '-{epoch:02d}-{val_flavour_acc:.2f}.h5'
        monitor_acc = 'val_flavour_acc'
        monitor_loss = 'val_flavour_loss'
else:
    # No validation set: monitor training metrics instead.
    if CHECKPOINT_SAVE_MANY:
        filepath = CHECKPOINT_PATH + CHECKPOINT_PREFIX + '-{epoch:02d}-{acc:.2f}.h5'
    monitor_acc = 'acc'
    monitor_loss = 'loss'

if PARALLELIZE:   # guard reconstructed: which checkpoint class to use depends on parallelization
    checkpoint = my_callbacks.ModelCheckpointDetached(filepath, monitor=monitor_acc, verbose=1,
                                                      save_best_only=CHECKPOINT_SAVE_BEST_ONLY,
                                                      save_weights_only=False, mode='max',
                                                      period=CHECKPOINT_PERIOD)
else:
    checkpoint = ModelCheckpoint(filepath, monitor=monitor_acc, verbose=1,
                                 save_best_only=CHECKPOINT_SAVE_BEST_ONLY, mode='max',
                                 period=CHECKPOINT_PERIOD)
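# Why two checkpoint classes: checkpointing a multi_gpu_model serializes the
# multi-GPU wrapper graph, which then needs the same device layout to reload.
# The 'detached' variant from my_callbacks presumably saves the underlying
# single-device model instead, keeping checkpoints portable (assumption based
# on its name).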
logging.info('Configuring learning rate reducer...')

lr_reducer = ReduceLROnPlateau(monitor=monitor_acc, mode='max', factor=0.1,
                               cooldown=0, patience=10, min_lr=0.5e-6, verbose=1)
logging.info('Configuring early stopping...')

early_stopping = EarlyStopping(monitor=monitor_acc, patience=EARLY_STOPPING_PATIENCE, mode='auto')

csv_logger = CSVLogger(LOG_PATH + LOG_PREFIX + '.log', append=RESUME)
logging.info('Setting callbacks...')

# 'my_callback' is referenced below but its construction is missing from this
# extract; the line here is an assumption (a custom callback exposed by the
# project's my_callbacks module, hypothetical class name):
my_callback = my_callbacks.MyCallback()

callbacks_list = [lr_reducer, checkpoint, early_stopping, csv_logger, my_callback]
'''
****************************************
*************** TRAINING ***************
****************************************
'''

initial_epoch = 0   # reconstructed default for fresh runs

if RESUME:
    # Recover the last epoch number from the CSV log: the last complete line
    # starts with the epoch index, so take its first integer and add one.
    with open(LOG_PATH + LOG_PREFIX + '.log', 'r') as logfile:
        initial_epoch = int(re.search(r'\d+', logfile.read().split('\n')[-2]).group()) + 1

    logging.info('RESUMING TRAINING...')
else:
    logging.info('STARTING TRAINING...')
if VALIDATION_FRACTION > 0:
    # Validate while training.
    validation_data = validation_generator
    validation_steps = len(partition['validation'])//VALIDATION_BATCH_SIZE
else:
    # Train without validation.
    validation_data = None
    validation_steps = None

model.fit_generator(generator=training_generator,
                    steps_per_epoch=len(partition['train'])//TRAIN_BATCH_SIZE,
                    validation_data=validation_data,
                    validation_steps=validation_steps,
                    epochs=EPOCHS,   # reconstructed: this argument is missing from the extract
                    class_weight=class_weights,
                    callbacks=callbacks_list,
                    initial_epoch=initial_epoch,
                    max_queue_size=MAX_QUEUE_SIZE,
                    use_multiprocessing=False)
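# Note on use_multiprocessing=False: plain Python generators like the ones
# built above cannot be safely copied across processes, so Keras would warn
# (and possibly duplicate data) with multiprocessing enabled; thread-based
# queuing with a single generator instance is the safe choice here.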