train Namespace Reference

Functions

def train (model, train_files, valid_files, maskpatterns, epochs, batchsize, info)
 
def adjust_learning_rate (optimizer, epoch, lr)
 
def main (train_dir, valid_dir, collection, induction, epochs, batchsize, model_name)
 
def parse_arguments ()
 

Variables

string __version__ = '1.0'
 
string __author__ = 'Saul Alonso-Monsalve'
 
string __email__ = "saul.alonso.monsalve@cern.ch"
 
 sess = tf.Session()
 
 init = tf.global_variables_initializer()
 
 stream
 
 stdout
 
 level
 
 config = configparser.ConfigParser()
 
 SEED = int(config['random']['seed'])
 
 SHUFFLE = ast.literal_eval(config['random']['shuffle'])
 
 IMAGES_PATH = config['images']['path']
 
 VIEWS = int(config['images']['views'])
 
 PLANES = int(config['images']['planes'])
 
 CELLS = int(config['images']['cells'])
 
 STANDARDIZE = ast.literal_eval(config['images']['standardize'])
 
 DATASET_PATH = config['dataset']['path']
 
 PARTITION_PREFIX = config['dataset']['partition_prefix']
 
 LABELS_PREFIX = config['dataset']['labels_prefix']
 
 LOG_PATH = config['log']['path']
 
 LOG_PREFIX = config['log']['prefix']
 
 ARCHITECTURE = config['model']['architecture']
 
 CHECKPOINT_PATH = config['model']['checkpoint_path']
 
 CHECKPOINT_PREFIX = config['model']['checkpoint_prefix']
 
 CHECKPOINT_SAVE_MANY = ast.literal_eval(config['model']['checkpoint_save_many'])
 
 CHECKPOINT_SAVE_BEST_ONLY = ast.literal_eval(config['model']['checkpoint_save_best_only'])
 
 CHECKPOINT_PERIOD = int(config['model']['checkpoint_period'])
 
 PARALLELIZE = ast.literal_eval(config['model']['parallelize'])
 
 GPUS = int(config['model']['gpus'])
 
 PRINT_SUMMARY = ast.literal_eval(config['model']['print_summary'])
 
 BRANCHES = ast.literal_eval(config['model']['branches'])
 
 OUTPUTS = int(config['model']['outputs'])
 
 RESUME = ast.literal_eval(config['train']['resume'])
 
 LEARNING_RATE = float(config['train']['lr'])
 
 MOMENTUM = float(config['train']['momentum'])
 
 DECAY = float(config['train']['decay'])
 
 TRAIN_BATCH_SIZE = int(config['train']['batch_size'])
 
 EPOCHS = int(config['train']['epochs'])
 
 EARLY_STOPPING_PATIENCE = int(config['train']['early_stopping_patience'])
 
 WEIGHTED_LOSS_FUNCTION = ast.literal_eval(config['train']['weighted_loss_function'])
 
 CLASS_WEIGHTS_PREFIX = config['train']['class_weights_prefix']
 
 MAX_QUEUE_SIZE = int(config['train']['max_queue_size'])
 
 VALIDATION_FRACTION = float(config['validation']['fraction'])
 
 VALIDATION_BATCH_SIZE = int(config['validation']['batch_size'])
 
dictionary TRAIN_PARAMS
 
dictionary VALIDATION_PARAMS
 
dictionary partition = {'train' : [], 'validation' : [], 'test' : []}
 
dictionary labels = {}
 
 class_weights = pickle.load(class_weights_file)
 
 training_generator = DataGenerator(**TRAIN_PARAMS).generate(labels, partition['train'], True)
 
 validation_generator = DataGenerator(**VALIDATION_PARAMS).generate(labels, partition['validation'], True)
 
 opt = optimizers.SGD(lr=LEARNING_RATE, momentum=MOMENTUM, decay=DECAY, nesterov=True)
 
list files = [f for f in os.listdir(CHECKPOINT_PATH) if os.path.isfile(os.path.join(CHECKPOINT_PATH, f))]
 
 reverse
 
 r = re.compile(CHECKPOINT_PREFIX[1:] + '-.*-.*.h5')
 
string filename = CHECKPOINT_PATH+'/'
 
 sequential_model
 
list input_shape = [PLANES, CELLS, 1]
 
 aux_model = networks.create_model(network=ARCHITECTURE, input_shape=input_shape)
 
int weight_decay = 1
 
list x = [None]*OUTPUTS
 
 use_bias
 
 kernel_regularizer
 
 activation
 
 name
 
 model = multi_gpu_model(sequential_model, gpus=GPUS, cpu_relocation=True)
 
 num_outputs = len(sequential_model.output_names)
 
dictionary model_loss = {'categories':my_losses.masked_loss_categorical}
 
 loss
 
 optimizer
 
 metrics
 
string filepath = CHECKPOINT_PATH+CHECKPOINT_PREFIX+'.h5'
 
string monitor_acc = 'val_acc'
 
string monitor_loss = 'val_loss'
 
 checkpoint = my_callbacks.ModelCheckpointDetached(filepath, monitor=monitor_acc, verbose=1, save_best_only=CHECKPOINT_SAVE_BEST_ONLY, save_weights_only=False, mode='max', period=CHECKPOINT_PERIOD)
 
 lr_reducer = ReduceLROnPlateau(monitor=monitor_acc, mode='max', factor=0.1, cooldown=0, patience=10, min_lr=0.5e-6, verbose=1)
 
 early_stopping = EarlyStopping(monitor=monitor_acc, patience=EARLY_STOPPING_PATIENCE, mode='auto')
 
 csv_logger = CSVLogger(LOG_PATH + LOG_PREFIX + '.log', append=RESUME)
 
 my_callback = my_callbacks.MyCallback()
 
list callbacks_list = [lr_reducer, checkpoint, early_stopping, csv_logger, my_callback]
 
int initial_epoch = int(re.search(r'\d+', logfile.read().split('\n')[-2]).group())+1
 
 validation_data = validation_generator
 
 validation_steps = len(partition['validation'])//VALIDATION_BATCH_SIZE
 
 generator
 
 steps_per_epoch
 
 epochs
 
 class_weight
 
 callbacks
 
 max_queue_size
 
 verbose
 
 use_multiprocessing
 
 workers
 
 arguments = parse_arguments()
 

Detailed Description

This is the train module.
Train the infill networks.

Function Documentation

def train.adjust_learning_rate (   optimizer,
  epoch,
  lr 
)

Definition at line 151 of file train.py.

151 def adjust_learning_rate(optimizer, epoch, lr):
152     lr = lr * (0.5 ** (epoch // 20))
153     for param_group in optimizer.param_groups:
154         param_group['lr'] = lr
155 
156 
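
A minimal usage sketch, not taken from train.py itself: the helper implements a step decay, halving the learning rate every 20 epochs (inside train() the corresponding call at line 78 is commented out). The stand-in model and starting rate below are assumptions for illustration only.

    import torch

    model = torch.nn.Linear(4, 4)       # hypothetical stand-in model
    base_lr = 1.0e-4
    optimizer = torch.optim.RMSprop(model.parameters(), lr=base_lr)

    for epoch in range(60):
        # ... one epoch of training ...
        adjust_learning_rate(optimizer, epoch, base_lr)  # lr = base_lr * 0.5**(epoch // 20)
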
def train.main (   train_dir,
  valid_dir,
  collection,
  induction,
  epochs,
  batchsize,
  model_name 
)

Definition at line 157 of file train.py.

157 def main(train_dir, valid_dir, collection, induction, epochs, batchsize, model_name):
158     DEVICE = torch.device("cuda:0")
159 
160     if collection:
161         model = UnetCollection()
162         criterion = InfillLossCollection().to(device=DEVICE)
163         maskpatterns = [[55, 66, 78, 81, 89, 238, 370],
164                         [217, 219, 221, 223, 225, 227, 250, 251, 252, 361, 363, 365, 367, 369, 371],
165                         [20, 95, 134, 147, 196, 351],
166                         [2, 3, 25, 27, 29, 31, 33, 35, 289, 409, 411, 413, 415, 417, 419, 456],
167                         [4, 13, 424, 436]]
168         width = 480
169 
170     elif induction:
171         model = UnetInduction()
172         criterion = InfillLossInduction().to(device=DEVICE)
173         maskpatterns = [[1, 2, 4, 94, 200, 202, 204, 206, 208, 325, 400, 401, 442, 447, 453, 455,
174                          456, 472, 477, 571, 573],
175                         [0, 1, 76, 191, 193, 195, 197, 199, 400, 734, 739, 746],
176                         [114, 273, 401],
177                         [181, 183, 301, 303, 701, 703, 781, 783],
178                         [5, 151, 201, 241, 243, 257, 280, 303],
179                         [212],
180                         [0, 1, 238, 400, 648, 661],
181                         [0, 21, 23, 341, 343, 781, 783],
182                         [457, 560, 667, 784],
183                         [163, 230, 417, 419, 423, 429, 477, 629, 639],
184                         [1, 201, 281, 563]]
185         width = 800
186 
187     model = nn.DataParallel(model)
188     model.to(DEVICE)
189 
190     lr = 1.0e-4
191     momentum = 0.9
192     weight_decay = 1.0e-4
193     optimizer = torch.optim.RMSprop(model.parameters(), lr=lr, weight_decay=weight_decay)  # note: momentum is set above but not passed to RMSprop
194 
195     train_files = [ os.path.join(train_dir, filename) for filename in os.listdir(train_dir) if filename.endswith(".npy") ]
196     valid_files = [ os.path.join(valid_dir, filename) for filename in os.listdir(valid_dir) if filename.endswith(".npy") ]
197 
198     info = {
199         "DEVICE" : DEVICE,
200         "criterion" : criterion,
201         "optimizer" : optimizer,
202         "lr" : lr,
203         "model_name" : model_name,
204         "width" : width
205     }
206 
207     train(model, train_files, valid_files, maskpatterns, epochs, batchsize, info)
208 
209 
def train.parse_arguments ( )

Definition at line 210 of file train.py.

210 def parse_arguments():
211     parser = argparse.ArgumentParser()
212 
213     parser.add_argument("train_dir")
214     parser.add_argument("valid_dir")
215 
216     group = parser.add_mutually_exclusive_group(required=True)
217     group.add_argument("--collection", action='store_true')
218     group.add_argument("--induction", action='store_true')
219 
220     parser.add_argument("-e", "--epochs", nargs='?', type=int, default=10, action='store', dest='EPOCHS')
221     parser.add_argument("-b", "--batchsize", nargs='?', type=int, default=12, action='store', dest='BATCHSIZE')
222     parser.add_argument("--model_name", nargs='?', type=str, action='store', dest='MODEL_NAME',
223                         default="{}".format(datetime.datetime.now().strftime("_%d%m%Y-%H%M%S")))
224 
225     args = parser.parse_args()
226 
227     return (args.train_dir, args.valid_dir, args.collection, args.induction, args.EPOCHS,
228             args.BATCHSIZE, args.MODEL_NAME)
229 
230 
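
The module-level member arguments = parse_arguments() (line 232) indicates the script is meant to be run directly, e.g. python train.py <train_dir> <valid_dir> --collection -e 25 -b 8. A hypothetical driver mirroring that call, shown only to make the wiring explicit:

    # Hypothetical wiring, matching the tuple returned by parse_arguments():
    train_dir, valid_dir, collection, induction, epochs, batchsize, model_name = parse_arguments()
    main(train_dir, valid_dir, collection, induction, epochs, batchsize, model_name)
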
def train.train (   model,
  train_files,
  valid_files,
  maskpatterns,
  epochs,
  batchsize,
  info 
)

Definition at line 16 of file train.py.

16 def train(model, train_files, valid_files, maskpatterns, epochs, batchsize, info):
17     overtrain_cntr = 0
18     train_losses, valid_losses = [], []
19     summary = {}
20     now = datetime.datetime.now().strftime("%d%m%Y-%H%M%S")
21 
22     batchsizes_train = [batchsize]*(int((len(train_files)/batchsize)))
23     batchsizes_train.append(len(train_files) % batchsize)
24     batchsizes_valid = [batchsize]*(int((len(valid_files)/batchsize)))
25     batchsizes_valid.append(len(valid_files) % batchsize)
26     if batchsizes_train[-1] == 0:
27         batchsizes_train.pop()
28     if batchsizes_valid[-1] == 0:
29         batchsizes_valid.pop()
30 
31     for epoch in range(epochs):
32         model.train()
33 
34         epoch_running_train_loss = 0.0
35         running_loss = 0.0
36 
37         random.shuffle(train_files)
38         files_for_batches = np.split(np.array(train_files), [ sum(batchsizes_train[:i]) for i in range(1, len(batchsizes_train)) ])
39 
40         for idx, batch_files in enumerate(files_for_batches):
41             masked_tensor_batch_lst, true_tensor_batch_lst = [], []
42 
43             for batch_idx, filepath in enumerate(batch_files):
44                 arr = np.load(filepath).T
45 
46                 maskpattern = random.sample(maskpatterns, 1)[0]
47                 offset = random.randint(1, info["width"] - 1) # Exclude offset = 0 for validation set.
48                 offset_maskpattern = [ i + offset if (i + offset) < info["width"] else i - info["width"] + offset for i in maskpattern ]
49                 arr_mask = np.copy(arr)
50                 arr_mask[:, offset_maskpattern] = 0
51 
52                 masked_tensor_batch_lst.append(torch.FloatTensor(arr_mask.reshape(1, *arr_mask.shape)))
53                 true_tensor_batch_lst.append(torch.FloatTensor(arr.reshape(1, *arr.shape)))
54 
55             masked_tensor_batch = torch.stack(masked_tensor_batch_lst)
56             del masked_tensor_batch_lst
57             true_tensor_batch = torch.stack(true_tensor_batch_lst)
58             del true_tensor_batch_lst
59 
60             masked_tensor_batch = masked_tensor_batch.to(info["DEVICE"])
61             true_tensor_batch = true_tensor_batch.to(info["DEVICE"])
62 
63             info["optimizer"].zero_grad()
64             outputs = model(masked_tensor_batch)
65             loss = info["criterion"](outputs, true_tensor_batch, masked_tensor_batch)[5]
66             loss.backward()
67             info["optimizer"].step()
68 
69             del masked_tensor_batch
70             del true_tensor_batch
71 
72             running_loss += loss.item()
73             epoch_running_train_loss += loss.item()
74             if (idx + 1) % 5 == 0:
75                 print('[{}, {:2.2%}] loss: {:.2f}'.format(epoch + 1, (idx*batchsize)/float(len(train_files)), running_loss/5))
76                 running_loss = 0.0
77 
78         # adjust_learning_rate(info["optimizer"], epoch, info["lr"]) # lr decay
79 
80         train_losses.append(epoch_running_train_loss/len(files_for_batches))
81 
82         model.eval()
83 
84         running_loss = 0.0
85 
86         files_for_batches = np.split(np.array(valid_files), [ sum(batchsizes_valid[:i]) for i in range(1, len(batchsizes_valid)) ])
87         maskpattern_pool = itertools.cycle(maskpatterns)
88 
89         with torch.no_grad():
90             for batch_files in files_for_batches:
91                 masked_tensor_batch_lst, true_tensor_batch_lst = [], []
92 
93                 for batch_idx, filepath in enumerate(batch_files):
94                     arr = np.load(filepath).T
95 
96                     maskpattern = next(maskpattern_pool) # Use true mask patterns for validation
97                     arr_mask = np.copy(arr)
98                     arr_mask[:, maskpattern] = 0
99 
100                     masked_tensor_batch_lst.append(torch.FloatTensor(arr_mask.reshape(1, *arr_mask.shape)))
101                     true_tensor_batch_lst.append(torch.FloatTensor(arr.reshape(1, *arr.shape)))
102 
103                 masked_tensor_batch = torch.stack(masked_tensor_batch_lst)
104                 del masked_tensor_batch_lst
105                 true_tensor_batch = torch.stack(true_tensor_batch_lst)
106                 del true_tensor_batch_lst
107 
108                 masked_tensor_batch = masked_tensor_batch.to(info["DEVICE"])
109                 true_tensor_batch = true_tensor_batch.to(info["DEVICE"])
110 
111                 outputs = model(masked_tensor_batch)
112                 loss = info["criterion"](outputs, true_tensor_batch, masked_tensor_batch)[5]
113 
114                 del masked_tensor_batch
115                 del true_tensor_batch
116 
117                 running_loss += loss.item()
118 
119         valid_losses.append(running_loss/len(files_for_batches))
120         print("Validation loss: {:.2f}".format(running_loss/len(files_for_batches)))
121 
122         summary['train losses'] = train_losses
123         summary['valid losses'] = valid_losses
124 
125         if epoch == 0:
126             torch.save(model.module.state_dict(), info["model_name"] + '.pth')
127             old_valid_loss = valid_losses[0]
128 
129         else:
130             if (valid_losses[-1] - old_valid_loss) < 0:
131                 torch.save(model.module.state_dict(), info["model_name"] + '.pth')
132                 old_valid_loss = valid_losses[-1]
133                 overtrain_cntr = 0
134                 summary['best epoch'] = epoch
135                 summary['best valid loss'] = valid_losses[-1]
136 
137             else:
138                 overtrain_cntr += 1
139 
140         if overtrain_cntr > 5:
141             break
142 
143     with open('training_summary_{}.yaml'.format(now), 'w') as f:
144         yaml.dump(summary, f)
145 
146     print("best valid loss: {} (at epoch {})".format(summary['best valid loss'], summary['best epoch']))
147     print("train losses: {}\n".format(train_losses))
148     print("valid losses: {}\n".format(valid_losses))
149 
150 
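
The wrap-around shift on line 48 is easier to see with concrete numbers. A standalone sketch, with width matching the collection plane from main() and a maskpattern invented for illustration:

    width = 480                     # collection-plane width set in main()
    maskpattern = [55, 66, 470]     # hypothetical dead-channel pattern
    offset = 20                     # train() draws random.randint(1, width - 1)

    # Shift every dead channel by offset; an index that would fall past the
    # last channel wraps back to the start (written i - width + offset on line 48).
    offset_maskpattern = [i + offset if (i + offset) < width else i - width + offset
                          for i in maskpattern]
    print(offset_maskpattern)       # [75, 86, 10]
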

Variable Documentation

string train.__author__ = 'Saul Alonso-Monsalve'
private

Definition at line 6 of file train.py.

string train.__email__ = "saul.alonso.monsalve@cern.ch"
private

Definition at line 7 of file train.py.

string train.__version__ = '1.0'
private

Definition at line 5 of file train.py.

train.activation

Definition at line 280 of file train.py.

train.ARCHITECTURE = config['model']['architecture']

Definition at line 89 of file train.py.

train.arguments = parse_arguments()

Definition at line 232 of file train.py.

train.aux_model = networks.create_model(network=ARCHITECTURE, input_shape=input_shape)

Definition at line 271 of file train.py.

train.BRANCHES = ast.literal_eval(config['model']['branches'])

Definition at line 98 of file train.py.

train.callbacks

Definition at line 474 of file train.py.

list train.callbacks_list = [lr_reducer, checkpoint, early_stopping, csv_logger, my_callback]

Definition at line 430 of file train.py.
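
The bare members generator, steps_per_epoch, epochs, class_weight, callbacks, max_queue_size, verbose, use_multiprocessing and workers documented on this page are Doxygen's rendering of the keyword arguments of a single Keras fit_generator call. A sketch of how callbacks_list and the other members plausibly combine there; the steps_per_epoch expression, worker count and flags are assumptions:

    model.fit_generator(generator=training_generator,
                        steps_per_epoch=len(partition['train'])//TRAIN_BATCH_SIZE,
                        epochs=EPOCHS,
                        class_weight=class_weights,
                        callbacks=callbacks_list,
                        validation_data=validation_generator,
                        validation_steps=validation_steps,
                        initial_epoch=initial_epoch,
                        max_queue_size=MAX_QUEUE_SIZE,
                        verbose=1,
                        use_multiprocessing=False,
                        workers=1)
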

train.CELLS = int(config['images']['cells'])

Definition at line 73 of file train.py.

train.checkpoint = my_callbacks.ModelCheckpointDetached(filepath, monitor=monitor_acc, verbose=1, save_best_only=CHECKPOINT_SAVE_BEST_ONLY, save_weights_only=False, mode='max', period=CHECKPOINT_PERIOD)

Definition at line 395 of file train.py.

train.CHECKPOINT_PATH = config['model']['checkpoint_path']

Definition at line 90 of file train.py.

train.CHECKPOINT_PERIOD = int(config['model']['checkpoint_period'])

Definition at line 94 of file train.py.

train.CHECKPOINT_PREFIX = config['model']['checkpoint_prefix']

Definition at line 91 of file train.py.

train.CHECKPOINT_SAVE_BEST_ONLY = ast.literal_eval(config['model']['checkpoint_save_best_only'])

Definition at line 93 of file train.py.

train.CHECKPOINT_SAVE_MANY = ast.literal_eval(config['model']['checkpoint_save_many'])

Definition at line 92 of file train.py.

train.class_weight

Definition at line 473 of file train.py.

train.class_weights = pickle.load(class_weights_file)

Definition at line 165 of file train.py.

train.CLASS_WEIGHTS_PREFIX = config['train']['class_weights_prefix']

Definition at line 111 of file train.py.

train.config = configparser.ConfigParser()

Definition at line 55 of file train.py.
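
Every uppercase constant above is parsed from one INI file through this ConfigParser instance. A self-contained sketch of the section layout it implies, mirroring the module's own int()/ast.literal_eval() idioms; all values are placeholders, not the project's real settings:

    import ast
    import configparser

    # Hypothetical fragment of the INI file this module reads:
    SAMPLE = """
    [random]
    seed = 7
    shuffle = True

    [images]
    path = /data/images
    views = 3
    planes = 500
    cells = 500
    standardize = True

    [train]
    lr = 0.01
    momentum = 0.9
    decay = 0.0001
    batch_size = 32
    epochs = 100
    """

    config = configparser.ConfigParser()
    config.read_string(SAMPLE)
    PLANES = int(config['images']['planes'])                 # 500
    SHUFFLE = ast.literal_eval(config['random']['shuffle'])  # True
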

train.csv_logger = CSVLogger(LOG_PATH + LOG_PREFIX + '.log', append=RESUME)

Definition at line 415 of file train.py.

train.DATASET_PATH = config['dataset']['path']

Definition at line 78 of file train.py.

train.DECAY = float(config['train']['decay'])

Definition at line 106 of file train.py.

train.early_stopping = EarlyStopping(monitor=monitor_acc, patience=EARLY_STOPPING_PATIENCE, mode='auto')

Definition at line 411 of file train.py.

train.EARLY_STOPPING_PATIENCE = int(config['train']['early_stopping_patience'])

Definition at line 109 of file train.py.

train.EPOCHS = int(config['train']['epochs'])

Definition at line 108 of file train.py.

train.epochs

Definition at line 472 of file train.py.

string train.filename = CHECKPOINT_PATH+'/'

Definition at line 213 of file train.py.

string train.filepath = CHECKPOINT_PATH+CHECKPOINT_PREFIX+'.h5'

Definition at line 371 of file train.py.

list train.files = [f for f in os.listdir(CHECKPOINT_PATH) if os.path.isfile(os.path.join(CHECKPOINT_PATH, f))]

Definition at line 206 of file train.py.

train.generator

Definition at line 468 of file train.py.

train.GPUS = int(config['model']['gpus'])

Definition at line 96 of file train.py.

train.IMAGES_PATH = config['images']['path']

Definition at line 70 of file train.py.

train.init = tf.global_variables_initializer()

Definition at line 42 of file train.py.

int train.initial_epoch = int(re.search(r'\d+', logfile.read().split('\n')[-2]).group())+1

Definition at line 445 of file train.py.
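
This expression recovers the last completed epoch from the penultimate line of the CSVLogger file when resuming. A worked sketch with a hypothetical log tail:

    import re

    log_text = "epoch,acc,loss\n0,0.71,0.9\n1,0.78,0.6\n"
    last_line = log_text.split('\n')[-2]                    # "1,0.78,0.6"
    initial_epoch = int(re.search(r'\d+', last_line).group()) + 1
    print(initial_epoch)                                    # 2: resume from epoch 2
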

list train.input_shape = [PLANES, CELLS, 1]

Definition at line 234 of file train.py.

train.kernel_regularizer

Definition at line 279 of file train.py.

train.labels = {}

Definition at line 151 of file train.py.

train.LABELS_PREFIX = config['dataset']['labels_prefix']

Definition at line 80 of file train.py.

train.LEARNING_RATE = float(config['train']['lr'])

Definition at line 104 of file train.py.

train.level

Definition at line 53 of file train.py.

train.LOG_PATH = config['log']['path']

Definition at line 84 of file train.py.

train.LOG_PREFIX = config['log']['prefix']

Definition at line 85 of file train.py.

train.loss

Definition at line 346 of file train.py.

train.lr_reducer = ReduceLROnPlateau(monitor=monitor_acc, mode='max', factor=0.1, cooldown=0, patience=10, min_lr=0.5e-6, verbose=1)

Definition at line 405 of file train.py.

train.MAX_QUEUE_SIZE = int(config['train']['max_queue_size'])

Definition at line 112 of file train.py.

train.max_queue_size

Definition at line 476 of file train.py.

train.metrics

Definition at line 350 of file train.py.

train.model = multi_gpu_model(sequential_model, gpus=GPUS, cpu_relocation=True)

Definition at line 313 of file train.py.

dictionary train.model_loss = {'categories':my_losses.masked_loss_categorical}

Definition at line 328 of file train.py.

train.MOMENTUM = float(config['train']['momentum'])

Definition at line 105 of file train.py.

string train.monitor_acc = 'val_acc'

Definition at line 378 of file train.py.

string train.monitor_loss = 'val_loss'

Definition at line 379 of file train.py.

train.my_callback = my_callbacks.MyCallback()

Definition at line 419 of file train.py.

train.name

Definition at line 280 of file train.py.

train.num_outputs = len(sequential_model.output_names)

Definition at line 314 of file train.py.

train.opt = optimizers.SGD(lr=LEARNING_RATE, momentum=MOMENTUM, decay=DECAY, nesterov=True)

Definition at line 196 of file train.py.

train.optimizer

Definition at line 349 of file train.py.

train.OUTPUTS = int(config['model']['outputs'])

Definition at line 99 of file train.py.

train.PARALLELIZE = ast.literal_eval(config['model']['parallelize'])

Definition at line 95 of file train.py.

train.partition = {'train' : [], 'validation' : [], 'test' : []}

Definition at line 150 of file train.py.

train.PARTITION_PREFIX = config['dataset']['partition_prefix']

Definition at line 79 of file train.py.

train.PLANES = int(config['images']['planes'])

Definition at line 72 of file train.py.

train.PRINT_SUMMARY = ast.literal_eval(config['model']['print_summary'])

Definition at line 97 of file train.py.

train.r = re.compile(CHECKPOINT_PREFIX[1:] + '-.*-.*.h5')

Definition at line 209 of file train.py.

train.RESUME = ast.literal_eval(config['train']['resume'])

Definition at line 103 of file train.py.

train.reverse

Definition at line 207 of file train.py.

train.SEED = int(config['random']['seed'])

Definition at line 60 of file train.py.

train.sequential_model
Initial value:
1 = load_model(filename,
2  custom_objects={'tf':tf,
3  'masked_loss':my_losses.masked_loss,
4  'multitask_loss':my_losses.multitask_loss,
5  'masked_loss_binary':my_losses.masked_loss_binary,
6  'masked_loss_categorical':my_losses.masked_loss_categorical})

Definition at line 214 of file train.py.

train.sess = tf.Session()

Definition at line 41 of file train.py.

train.SHUFFLE = ast.literal_eval(config['random']['shuffle'])

Definition at line 66 of file train.py.

train.STANDARDIZE = ast.literal_eval(config['images']['standardize'])

Definition at line 74 of file train.py.

train.stdout

Definition at line 53 of file train.py.

train.steps_per_epoch

Definition at line 469 of file train.py.

train.stream

Definition at line 53 of file train.py.

train.TRAIN_BATCH_SIZE = int(config['train']['batch_size'])

Definition at line 107 of file train.py.

dictionary train.TRAIN_PARAMS
Initial value:
1 = {'planes':PLANES,
2  'cells':CELLS,
3  'views':VIEWS,
4  'batch_size':TRAIN_BATCH_SIZE,
5  'branches':BRANCHES,
6  'outputs': OUTPUTS,
7  'images_path':IMAGES_PATH,
8  'standardize':STANDARDIZE,
9  'shuffle':SHUFFLE}

Definition at line 121 of file train.py.

train.training_generator = DataGenerator(**TRAIN_PARAMS).generate(labels, partition['train'], True)

Definition at line 182 of file train.py.

train.use_bias

Definition at line 279 of file train.py.

train.use_multiprocessing

Definition at line 478 of file train.py.

train.VALIDATION_BATCH_SIZE = int(config['validation']['batch_size'])

Definition at line 117 of file train.py.

train.validation_data = validation_generator

Definition at line 459 of file train.py.

train.VALIDATION_FRACTION = float(config['validation']['fraction'])

Definition at line 116 of file train.py.

train.validation_generator = DataGenerator(**VALIDATION_PARAMS).generate(labels, partition['validation'], True)

Definition at line 183 of file train.py.

dictionary train.VALIDATION_PARAMS
Initial value:
1 = {'planes':PLANES,
2  'cells':CELLS,
3  'views':VIEWS,
4  'batch_size':VALIDATION_BATCH_SIZE,
5  'branches':BRANCHES,
6  'outputs': OUTPUTS,
7  'images_path':IMAGES_PATH,
8  'standardize':STANDARDIZE,
9  'shuffle':SHUFFLE}

Definition at line 133 of file train.py.

train.validation_steps = len(partition['validation'])//VALIDATION_BATCH_SIZE

Definition at line 460 of file train.py.

train.verbose

Definition at line 477 of file train.py.

train.VIEWS = int(config['images']['views'])

Definition at line 71 of file train.py.

int train.weight_decay = 1

Definition at line 274 of file train.py.

train.WEIGHTED_LOSS_FUNCTION = ast.literal_eval(config['train']['weighted_loss_function'])

Definition at line 110 of file train.py.

train.workers

Definition at line 479 of file train.py.

list train.x = [None]*OUTPUTS

Definition at line 276 of file train.py.