Functions | |
| def | load_model (name) |
| def | save_model (model, name) |
| def | generate_data_generator (generator, X, Y1, Y2, b) |
Variables | |
| parser = argparse.ArgumentParser(description='Run CNN training on patches with a few different hyperparameter sets.') | |
| help | |
| default | |
| args = parser.parse_args() | |
| config = read_config(args.config) | |
| configuration ############################# More... | |
| cfg_name = args.model | |
| out_name = args.output | |
| CNN_INPUT_DIR = config['training_on_patches']['input_dir'] | |
| PATCH_SIZE_W | |
| PATCH_SIZE_D | |
| img_rows | |
| img_cols | |
| batch_size = config['training_on_patches']['batch_size'] | |
| nb_epoch = config['training_on_patches']['nb_epoch'] | |
| nb_classes = config['training_on_patches']['nb_classes'] | |
| model = load_model(cfg_name) | |
| CNN compilation ###########################. More... | |
| sgd = SGD(lr=0.002, decay=1e-5, momentum=0.9, nesterov=True) | |
| optimizer | |
| loss | |
| loss_weights | |
| n_training = count_events(CNN_INPUT_DIR, 'training') | |
| read data sets ############################ More... | |
| X_train = np.zeros((n_training, PATCH_SIZE_W, PATCH_SIZE_D, 1), dtype=np.float32) | |
| EmTrkNone_train = np.zeros((n_training, 3), dtype=np.int32) | |
| Michel_train = np.zeros((n_training, 1), dtype=np.int32) | |
| int | ntot = 0 |
| list | subdirs = [f for f in os.listdir(CNN_INPUT_DIR) if 'training' in f] |
| list | filesX = [f for f in os.listdir(CNN_INPUT_DIR + '/' + dirname) if '_x.npy' in f] |
| fnameY = fnameX.replace('_x.npy', '_y.npy') | |
| dataX = np.load(CNN_INPUT_DIR + '/' + dirname + '/' + fnameX) | |
| dataY = np.load(CNN_INPUT_DIR + '/' + dirname + '/' + fnameY) | |
| n = dataY.shape[0] | |
| n_testing = count_events(CNN_INPUT_DIR, 'testing') | |
| X_test = np.zeros((n_testing, PATCH_SIZE_W, PATCH_SIZE_D, 1), dtype=np.float32) | |
| EmTrkNone_test = np.zeros((n_testing, 3), dtype=np.int32) | |
| Michel_test = np.zeros((n_testing, 1), dtype=np.int32) | |
| datagen | |
| training ############################### More... | |
| h | |
| score | |
| def train_cnn_continue_augmented_data.generate_data_generator (generator, X, Y1, Y2, b) |
Definition at line 143 of file train_cnn_continue_augmented_data.py.
| def train_cnn_continue_augmented_data.load_model (name) |
Definition at line 32 of file train_cnn_continue_augmented_data.py.
| def train_cnn_continue_augmented_data.save_model (model, name) |
Definition at line 38 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.args = parser.parse_args() |
Definition at line 7 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.batch_size = config['training_on_patches']['batch_size'] |
Definition at line 60 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.cfg_name = args.model |
Definition at line 52 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.CNN_INPUT_DIR = config['training_on_patches']['input_dir'] |
Definition at line 55 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.config = read_config(args.config) |
configuration #############################
Definition at line 50 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.datagen |
training ###############################
Definition at line 133 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.dataX = np.load(CNN_INPUT_DIR + '/' + dirname + '/' + fnameX) |
Definition at line 90 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.dataY = np.load(CNN_INPUT_DIR + '/' + dirname + '/' + fnameY) |
Definition at line 93 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.default |
Definition at line 3 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.EmTrkNone_test = np.zeros((n_testing, 3), dtype=np.int32) |
Definition at line 103 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.EmTrkNone_train = np.zeros((n_training, 3), dtype=np.int32) |
Definition at line 77 of file train_cnn_continue_augmented_data.py.
| list train_cnn_continue_augmented_data.filesX = [f for f in os.listdir(CNN_INPUT_DIR + '/' + dirname) if '_x.npy' in f] |
Definition at line 86 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.fnameY = fnameX.replace('_x.npy', '_y.npy') |
Definition at line 89 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.h |
Definition at line 152 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.help |
Definition at line 3 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.img_cols |
Definition at line 58 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.img_rows |
Definition at line 58 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.loss |
Definition at line 71 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.loss_weights |
Definition at line 72 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.Michel_test = np.zeros((n_testing, 1), dtype=np.int32) |
Definition at line 104 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.Michel_train = np.zeros((n_training, 1), dtype=np.int32) |
Definition at line 78 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.model = load_model(cfg_name) |
CNN compilation ###########################.
Definition at line 67 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.n = dataY.shape[0] |
Definition at line 94 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.n_testing = count_events(CNN_INPUT_DIR, 'testing') |
Definition at line 101 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.n_training = count_events(CNN_INPUT_DIR, 'training') |
read data sets ############################
Definition at line 75 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.nb_classes = config['training_on_patches']['nb_classes'] |
Definition at line 62 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.nb_epoch = config['training_on_patches']['nb_epoch'] |
Definition at line 61 of file train_cnn_continue_augmented_data.py.
| int train_cnn_continue_augmented_data.ntot = 0 |
Definition at line 81 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.optimizer |
Definition at line 70 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.out_name = args.output |
Definition at line 53 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.parser = argparse.ArgumentParser(description='Run CNN training on patches with a few different hyperparameter sets.') |
Definition at line 2 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.PATCH_SIZE_D |
Definition at line 57 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.PATCH_SIZE_W |
Definition at line 57 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.score |
Definition at line 164 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.sgd = SGD(lr=0.002, decay=1e-5, momentum=0.9, nesterov=True) |
Definition at line 69 of file train_cnn_continue_augmented_data.py.
| list train_cnn_continue_augmented_data.subdirs = [f for f in os.listdir(CNN_INPUT_DIR) if 'training' in f] |
Definition at line 82 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.X_test = np.zeros((n_testing, PATCH_SIZE_W, PATCH_SIZE_D, 1), dtype=np.float32) |
Definition at line 102 of file train_cnn_continue_augmented_data.py.
| train_cnn_continue_augmented_data.X_train = np.zeros((n_training, PATCH_SIZE_W, PATCH_SIZE_D, 1), dtype=np.float32) |
Definition at line 76 of file train_cnn_continue_augmented_data.py.
1.8.11