import sys
import os
import re
import ast
import logging
import pickle
import configparser
import numpy as np

from sklearn.utils import class_weight
from keras.models import Model, Sequential, load_model
from keras.layers import Input, Dense, Activation, ZeroPadding2D, Dropout, Flatten, BatchNormalization, SeparableConv2D
from keras import regularizers, optimizers
from keras.layers.convolutional import Conv2D, MaxPooling2D, AveragePooling2D
from keras.callbacks import LearningRateScheduler, ReduceLROnPlateau, CSVLogger, ModelCheckpoint, EarlyStopping
from data_generator import DataGenerator
from collections import Counter

sys.path.append("/home/salonsom/cvn_tensorflow/networks")
sys.path.append("/home/salonsom/cvn_tensorflow/callbacks")

import se_resnet, resnet, resnetpa, googlenet, my_model

from keras import backend as K

K.set_image_data_format('channels_last')
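# with 'channels_last' the image tensors are laid out as
# (batch, height, width, channels); for the [PLANES, CELLS, VIEWS] input
# shape used below, the detector views play the role of the channels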

#****************************************
#************** PARAMETERS **************
#****************************************

logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)

config = configparser.ConfigParser()
config.read('config.ini')
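
# A minimal config.ini sketch covering every key this script reads
# (the values below are illustrative assumptions, not the original settings):
#
#   [random]
#   seed = 1
#   shuffle = True
#
#   [images]
#   path = /path/to/images
#   views = 3
#   planes = 500
#   cells = 500
#   standardize = False
#   interaction_labels = {'0': 0, '1': 1}
#   neutrino_labels = {'0': 0, '1': 1}
#   filtered = False
#
#   [dataset]
#   path = /path/to/dataset
#   partition_prefix = /partition
#   labels_prefix = /labels
#   interaction_types = True
#
#   [log]
#   path = /path/to/logs
#   prefix = /train
#
#   [model]
#   checkpoint_path = /path/to/checkpoints
#   checkpoint_prefix = /model
#   checkpoint_save_many = True
#   checkpoint_save_best_only = False
#   checkpoint_period = 1
#   print_summary = True
#
#   [train]
#   resume = False
#   lr = 0.1
#   momentum = 0.9
#   decay = 0.0001
#   batch_size = 32
#   epochs = 100
#   early_stopping_patience = 10
#   weighted_loss_function = False
#   class_weights_prefix = /class_weights
#
#   [validation]
#   fraction = 0.1
#   batch_size = 32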

np.random.seed(int(config['random']['seed']))
SHUFFLE = ast.literal_eval(config['random']['shuffle'])
IMAGES_PATH = config['images']['path']
VIEWS = int(config['images']['views'])
PLANES = int(config['images']['planes'])
CELLS = int(config['images']['cells'])
STANDARDIZE = ast.literal_eval(config['images']['standardize'])
INTERACTION_LABELS = ast.literal_eval(config['images']['interaction_labels'])
FILTERED = ast.literal_eval(config['images']['filtered'])

INTERACTION_TYPES = ast.literal_eval(config['dataset']['interaction_types'])
NEUTRINO_LABELS = {}  # assumed fallback so the name exists in both branches (not in the original excerpt)

if INTERACTION_TYPES:
    # one output class per interaction type
    N_LABELS = len(Counter(INTERACTION_LABELS.values()))
else:
    # one output class per neutrino flavour
    NEUTRINO_LABELS = ast.literal_eval(config['images']['neutrino_labels'])
    N_LABELS = len(Counter(NEUTRINO_LABELS.values()))
DATASET_PATH = config['dataset']['path']
PARTITION_PREFIX = config['dataset']['partition_prefix']
LABELS_PREFIX = config['dataset']['labels_prefix']

LOG_PATH = config['log']['path']
LOG_PREFIX = config['log']['prefix']

CHECKPOINT_PATH = config['model']['checkpoint_path']
CHECKPOINT_PREFIX = config['model']['checkpoint_prefix']
CHECKPOINT_SAVE_MANY = ast.literal_eval(config['model']['checkpoint_save_many'])
CHECKPOINT_SAVE_BEST_ONLY = ast.literal_eval(config['model']['checkpoint_save_best_only'])
CHECKPOINT_PERIOD = int(config['model']['checkpoint_period'])
PRINT_SUMMARY = ast.literal_eval(config['model']['print_summary'])

RESUME = ast.literal_eval(config['train']['resume'])
LEARNING_RATE = float(config['train']['lr'])
MOMENTUM = float(config['train']['momentum'])
DECAY = float(config['train']['decay'])
TRAIN_BATCH_SIZE = int(config['train']['batch_size'])
EPOCHS = int(config['train']['epochs'])
EARLY_STOPPING_PATIENCE = int(config['train']['early_stopping_patience'])
WEIGHTED_LOSS_FUNCTION = ast.literal_eval(config['train']['weighted_loss_function'])
CLASS_WEIGHTS_PREFIX = config['train']['class_weights_prefix']

VALIDATION_FRACTION = float(config['validation']['fraction'])
VALIDATION_BATCH_SIZE = int(config['validation']['batch_size'])

TRAIN_PARAMS = {'planes': PLANES,
                'batch_size': TRAIN_BATCH_SIZE,
                'n_labels': N_LABELS,
                'interaction_labels': INTERACTION_LABELS,
                'interaction_types': INTERACTION_TYPES,
                'filtered': FILTERED,
                'neutrino_labels': NEUTRINO_LABELS,
                'images_path': IMAGES_PATH,
                'standardize': STANDARDIZE}

VALIDATION_PARAMS = {'planes': PLANES,
                     'batch_size': VALIDATION_BATCH_SIZE,
                     'n_labels': N_LABELS,
                     'interaction_labels': INTERACTION_LABELS,
                     'interaction_types': INTERACTION_TYPES,
                     'filtered': FILTERED,
                     'neutrino_labels': NEUTRINO_LABELS,
                     'images_path': IMAGES_PATH,
                     'standardize': STANDARDIZE}
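
# these two dicts are unpacked as keyword arguments into DataGenerator below,
# so their keys must line up with the parameters of DataGenerator.__init__
# (a project-local class imported from data_generator.py)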

#****************************************
#*************** DATASETS ***************
#****************************************

partition = {'train': [], 'validation': [], 'test': []}

logging.info('Loading datasets from serialized files...')

# the train/validation/test partition and the labels are stored as pickles;
# pickle files must be opened in binary mode
partition_file = open(DATASET_PATH + PARTITION_PREFIX + '.p', 'rb')
partition = pickle.load(partition_file)
partition_file.close()

labels_file = open(DATASET_PATH + LABELS_PREFIX + '.p', 'rb')
labels = pickle.load(labels_file)
labels_file.close()

if WEIGHTED_LOSS_FUNCTION:
    # per-class weights to compensate for class imbalance
    class_weights_file = open(DATASET_PATH + CLASS_WEIGHTS_PREFIX + '.p', 'rb')
    class_weights = pickle.load(class_weights_file)
    class_weights_file.close()
else:
    class_weights = None

logging.info('Number of training examples: %d', len(partition['train']))
logging.info('Number of validation examples: %d', len(partition['validation']))
logging.info('Number of test examples: %d', len(partition['test']))
logging.info('Class weights: %s', class_weights)

#****************************************
#************** GENERATORS **************
#****************************************

training_generator = DataGenerator(**TRAIN_PARAMS).generate(labels, partition['train'], True)
validation_generator = DataGenerator(**VALIDATION_PARAMS).generate(labels, partition['validation'], True)
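
# DataGenerator.generate() is assumed to be an endless generator yielding
# (batch_of_images, one-hot_labels) tuples, which is the contract Keras'
# fit_generator expects further down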

#****************************************
#*************** CVN MODEL **************
#****************************************

if RESUME:
    # resume training: restore the model from disk
    logging.info('Loading model from disk...')

    if CHECKPOINT_SAVE_MANY:
        # several checkpoints were kept: load the most recent file that
        # matches the '<prefix>-<epoch>-<accuracy>.h5' naming pattern
        files = [f for f in os.listdir(CHECKPOINT_PATH) if os.path.isfile(os.path.join(CHECKPOINT_PATH, f))]
        files.sort(reverse=True)

        r = re.compile(CHECKPOINT_PREFIX[1:] + '-.*-.*.h5')

        for fil in files:
            if r.match(fil) is not None:
                model = load_model(CHECKPOINT_PATH + '/' + fil)
                logging.info('Loaded model: %s', CHECKPOINT_PATH + '/' + fil)
                break
    else:
        # a single checkpoint file was kept
        model = load_model(CHECKPOINT_PATH + CHECKPOINT_PREFIX + '.h5')
        logging.info('Loaded model: %s', CHECKPOINT_PATH + CHECKPOINT_PREFIX + '.h5')
else:
    # train from scratch
    logging.info('Creating model...')

    input_shape = [PLANES, CELLS, VIEWS]

    # (architecture construction elided here; the model comes from one of the
    # imported network modules: se_resnet, resnet, resnetpa, googlenet, my_model)

    logging.info('Setting optimizer...')

    # stochastic gradient descent with Nesterov momentum
    opt = optimizers.SGD(lr=LEARNING_RATE, momentum=MOMENTUM, decay=DECAY, nesterov=True)

    logging.info('Compiling model...')

    model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])

# print the network layout if requested (assumed use of PRINT_SUMMARY,
# which is otherwise unused in this excerpt)
if PRINT_SUMMARY:
    model.summary()

#****************************************
#*************** CALLBACKS **************
#****************************************

logging.info('Configuring checkpointing...')

filepath = CHECKPOINT_PATH + CHECKPOINT_PREFIX + '.h5'

if VALIDATION_FRACTION > 0:
    # monitor the metrics on the validation set
    if CHECKPOINT_SAVE_MANY:
        filepath = CHECKPOINT_PATH + CHECKPOINT_PREFIX + '-{epoch:02d}-{val_acc:.2f}.h5'
    monitor_acc = 'val_acc'
    monitor_loss = 'val_loss'
else:
    # no validation set: monitor the training metrics instead
    if CHECKPOINT_SAVE_MANY:
        filepath = CHECKPOINT_PATH + CHECKPOINT_PREFIX + '-{epoch:02d}-{acc:.2f}.h5'
    monitor_acc = 'acc'
    monitor_loss = 'loss'

checkpoint = ModelCheckpoint(filepath, monitor=monitor_acc, verbose=1, save_best_only=CHECKPOINT_SAVE_BEST_ONLY, mode='max', period=CHECKPOINT_PERIOD)
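
# ModelCheckpoint fills the '{epoch:02d}' and '{val_acc:.2f}'/'{acc:.2f}'
# placeholders at save time, so with CHECKPOINT_SAVE_MANY a separate file
# such as 'model-07-0.83.h5' is written per checkpointed epoch; this is also
# the pattern the resume branch above matches with its regex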

logging.info('Configuring learning rate reducer...')

# reduce the learning rate by 10x when the monitored loss plateaus
lr_reducer = ReduceLROnPlateau(monitor=monitor_loss, factor=0.1, cooldown=0, patience=3, min_lr=0.5e-6, verbose=1)

logging.info('Configuring early stopping...')

early_stopping = EarlyStopping(monitor=monitor_acc, patience=EARLY_STOPPING_PATIENCE, mode='auto')

# append to the existing CSV log when resuming a previous training
csv_logger = CSVLogger(LOG_PATH + LOG_PREFIX + '.log', append=RESUME)

logging.info('Setting callbacks...')

callbacks_list = [lr_reducer, checkpoint, early_stopping, csv_logger]

#****************************************
#*************** TRAINING ***************
#****************************************

if RESUME:
    # recover the epoch counter from the CSV log: the file ends with a newline,
    # so split('\n')[-2] is the last written row, whose first integer is the
    # last completed epoch
    with open(LOG_PATH + LOG_PREFIX + '.log', 'r') as logfile:
        initial_epoch = int(re.search(r'\d+', logfile.read().split('\n')[-2]).group()) + 1

    logging.info('RESUMING TRAINING...')
else:
    # train from the first epoch
    initial_epoch = 0

    logging.info('STARTING TRAINING...')

if VALIDATION_FRACTION > 0:
    # train, validating at the end of every epoch
    model.fit_generator(generator = training_generator,
                        steps_per_epoch = len(partition['train'])//TRAIN_BATCH_SIZE,
                        validation_data = validation_generator,
                        validation_steps = len(partition['validation'])//VALIDATION_BATCH_SIZE,
                        epochs = EPOCHS,
                        class_weight = class_weights,
                        callbacks = callbacks_list,
                        initial_epoch = initial_epoch)
else:
    # train without a validation set
    model.fit_generator(generator = training_generator,
                        steps_per_epoch = len(partition['train'])//TRAIN_BATCH_SIZE,
                        epochs = EPOCHS,
                        class_weight = class_weights,
                        callbacks = callbacks_list,
                        initial_epoch = initial_epoch)
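
# since the generators yield batches endlessly, steps_per_epoch =
# len(partition['train']) // TRAIN_BATCH_SIZE makes each Keras "epoch"
# consume roughly the whole training set once (and analogously for
# validation_steps on the validation set)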