train.py
1 """
2 Train the infill networks.
3 """
4 
5 import os, datetime, random, argparse, itertools
6 import numpy as np
7 import yaml
8 
9 import torch
10 import torch.nn as nn
11 
12 from infill_loss import InfillLossInduction, InfillLossCollection
13 from model import UnetInduction, UnetCollection
14 
15 
def train(model, train_files, valid_files, maskpatterns, epochs, batchsize, info):
    overtrain_cntr = 0
    train_losses, valid_losses = [], []
    summary = {}
    now = datetime.datetime.now().strftime("%d%m%Y-%H%M%S")

    # Partition the file lists into full batches plus one remainder batch.
    batchsizes_train = [batchsize] * (len(train_files) // batchsize)
    batchsizes_train.append(len(train_files) % batchsize)
    batchsizes_valid = [batchsize] * (len(valid_files) // batchsize)
    batchsizes_valid.append(len(valid_files) % batchsize)
    if batchsizes_train[-1] == 0:
        batchsizes_train.pop()
    if batchsizes_valid[-1] == 0:
        batchsizes_valid.pop()
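    # For example, 10 training files with batchsize 4 give batch sizes [4, 4, 2],
    # so the np.split below cuts the shuffled file array at indices [4, 8].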

    for epoch in range(epochs):
        model.train()

        epoch_running_train_loss = 0.0
        running_loss = 0.0

        random.shuffle(train_files)
        files_for_batches = np.split(
            np.array(train_files),
            [sum(batchsizes_train[:i]) for i in range(1, len(batchsizes_train))])

        for idx, batch_files in enumerate(files_for_batches):
            masked_tensor_batch_lst, true_tensor_batch_lst = [], []

            for batch_idx, filepath in enumerate(batch_files):
                arr = np.load(filepath).T

                # Augment the training data by shifting a mask pattern to a random
                # offset, wrapping indices that run past the image width.
                maskpattern = random.sample(maskpatterns, 1)[0]
                offset = random.randint(1, info["width"] - 1)  # Exclude offset = 0 for validation set.
                offset_maskpattern = [
                    i + offset if (i + offset) < info["width"] else i + offset - info["width"]
                    for i in maskpattern]
                arr_mask = np.copy(arr)
                arr_mask[:, offset_maskpattern] = 0

                masked_tensor_batch_lst.append(torch.FloatTensor(arr_mask.reshape(1, *arr_mask.shape)))
                true_tensor_batch_lst.append(torch.FloatTensor(arr.reshape(1, *arr.shape)))

            masked_tensor_batch = torch.stack(masked_tensor_batch_lst)
            del masked_tensor_batch_lst
            true_tensor_batch = torch.stack(true_tensor_batch_lst)
            del true_tensor_batch_lst

            masked_tensor_batch = masked_tensor_batch.to(info["DEVICE"])
            true_tensor_batch = true_tensor_batch.to(info["DEVICE"])

            info["optimizer"].zero_grad()
            outputs = model(masked_tensor_batch)
            loss = info["criterion"](outputs, true_tensor_batch, masked_tensor_batch)[5]
            loss.backward()
            info["optimizer"].step()

            del masked_tensor_batch
            del true_tensor_batch

            running_loss += loss.item()
            epoch_running_train_loss += loss.item()
            if (idx + 1) % 5 == 0:
                print('[{}, {:2.2%}] loss: {:.2f}'.format(
                    epoch + 1, (idx*batchsize)/float(len(train_files)), running_loss/5))
                running_loss = 0.0

        # adjust_learning_rate(info["optimizer"], epoch, info["lr"]) # lr decay

        train_losses.append(epoch_running_train_loss/len(files_for_batches))

        model.eval()

        running_loss = 0.0

        files_for_batches = np.split(
            np.array(valid_files),
            [sum(batchsizes_valid[:i]) for i in range(1, len(batchsizes_valid))])
        maskpattern_pool = itertools.cycle(maskpatterns)

        with torch.no_grad():
            for batch_files in files_for_batches:
                masked_tensor_batch_lst, true_tensor_batch_lst = [], []

                for batch_idx, filepath in enumerate(batch_files):
                    arr = np.load(filepath).T

                    maskpattern = next(maskpattern_pool)  # Use true mask patterns for validation
                    arr_mask = np.copy(arr)
                    arr_mask[:, maskpattern] = 0

                    masked_tensor_batch_lst.append(torch.FloatTensor(arr_mask.reshape(1, *arr_mask.shape)))
                    true_tensor_batch_lst.append(torch.FloatTensor(arr.reshape(1, *arr.shape)))

                masked_tensor_batch = torch.stack(masked_tensor_batch_lst)
                del masked_tensor_batch_lst
                true_tensor_batch = torch.stack(true_tensor_batch_lst)
                del true_tensor_batch_lst

                masked_tensor_batch = masked_tensor_batch.to(info["DEVICE"])
                true_tensor_batch = true_tensor_batch.to(info["DEVICE"])

                outputs = model(masked_tensor_batch)
                loss = info["criterion"](outputs, true_tensor_batch, masked_tensor_batch)[5]

                del masked_tensor_batch
                del true_tensor_batch

                running_loss += loss.item()

        valid_losses.append(running_loss/len(files_for_batches))
        print("Validation loss: {:.2f}".format(running_loss/len(files_for_batches)))

        summary['train losses'] = train_losses
        summary['valid losses'] = valid_losses

        if epoch == 0:
            torch.save(model.module.state_dict(), info["model_name"] + '.pth')
            old_valid_loss = valid_losses[0]
            summary['best epoch'] = epoch
            summary['best valid loss'] = valid_losses[0]

        else:
            if (valid_losses[-1] - old_valid_loss) < 0:
                torch.save(model.module.state_dict(), info["model_name"] + '.pth')
                old_valid_loss = valid_losses[-1]
                overtrain_cntr = 0
                summary['best epoch'] = epoch
                summary['best valid loss'] = valid_losses[-1]

            else:
                overtrain_cntr += 1

        if overtrain_cntr > 5:
            break

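    # Checkpointing and early stopping: the model weights are rewritten whenever
    # the validation loss improves on the best value seen so far, and training
    # stops after 6 consecutive epochs without improvement. The YAML summary
    # below records the full loss history and the best epoch.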
    with open('training_summary_{}.yaml'.format(now), 'w') as f:
        yaml.dump(summary, f)

    print("best valid loss: {} (at epoch {})".format(summary['best valid loss'], summary['best epoch']))
    print("train losses: {}\n".format(train_losses))
    print("valid losses: {}\n".format(valid_losses))


def adjust_learning_rate(optimizer, epoch, lr):
    lr = lr * (0.5 ** (epoch // 20))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
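    # Halves the learning rate every 20 epochs: epochs 0-19 use lr, 20-39 use lr/2,
    # 40-59 use lr/4, and so on. The call site in train() is currently commented
    # out, so this decay is not applied during training.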


def main(train_dir, valid_dir, collection, induction, epochs, batchsize, model_name):
    DEVICE = torch.device("cuda:0")

    if collection:
        model = UnetCollection()
        criterion = InfillLossCollection().to(device=DEVICE)
        maskpatterns = [[55, 66, 78, 81, 89, 238, 370],
                        [217, 219, 221, 223, 225, 227, 250, 251, 252, 361, 363, 365, 367, 369, 371],
                        [20, 95, 134, 147, 196, 351],
                        [2, 3, 25, 27, 29, 31, 33, 35, 289, 409, 411, 413, 415, 417, 419, 456],
                        [4, 13, 424, 436]]
        width = 480

    elif induction:
        model = UnetInduction()
        criterion = InfillLossInduction().to(device=DEVICE)
        maskpatterns = [[1, 2, 4, 94, 200, 202, 204, 206, 208, 325, 400, 401, 442, 447, 453, 455,
                         456, 472, 477, 571, 573],
                        [0, 1, 76, 191, 193, 195, 197, 199, 400, 734, 739, 746],
                        [114, 273, 401],
                        [181, 183, 301, 303, 701, 703, 781, 783],
                        [5, 151, 201, 241, 243, 257, 280, 303],
                        [212],
                        [0, 1, 238, 400, 648, 661],
                        [0, 21, 23, 341, 343, 781, 783],
                        [457, 560, 667, 784],
                        [163, 230, 417, 419, 423, 429, 477, 629, 639],
                        [1, 201, 281, 563]]
        width = 800
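    # width is the channel (column) dimension of the input arrays: 480 for the
    # collection view, 800 for the induction view. Each mask pattern is presumably
    # a list of dead-channel indices that get zeroed out before infilling.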

    model = nn.DataParallel(model)
    model.to(DEVICE)

    lr = 1.0e-4
    momentum = 0.9  # NB: not passed to the RMSprop optimizer below
    weight_decay = 1.0e-4
    optimizer = torch.optim.RMSprop(model.parameters(), lr=lr, weight_decay=weight_decay)

    train_files = [ os.path.join(train_dir, filename) for filename in os.listdir(train_dir) if filename.endswith(".npy") ]
    valid_files = [ os.path.join(valid_dir, filename) for filename in os.listdir(valid_dir) if filename.endswith(".npy") ]

    info = {
        "DEVICE" : DEVICE,
        "criterion" : criterion,
        "optimizer" : optimizer,
        "lr" : lr,
        "model_name" : model_name,
        "width" : width
    }

    train(model, train_files, valid_files, maskpatterns, epochs, batchsize, info)


def parse_arguments():
    parser = argparse.ArgumentParser()

    parser.add_argument("train_dir")
    parser.add_argument("valid_dir")

    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--collection", action='store_true')
    group.add_argument("--induction", action='store_true')

    parser.add_argument("-e", "--epochs", nargs='?', type=int, default=10, action='store', dest='EPOCHS')
    parser.add_argument("-b", "--batchsize", nargs='?', type=int, default=12, action='store', dest='BATCHSIZE')
    parser.add_argument("--model_name", nargs='?', type=str, action='store', dest='MODEL_NAME',
                        default="{}".format(datetime.datetime.now().strftime("_%d%m%Y-%H%M%S")))

    args = parser.parse_args()

    return (args.train_dir, args.valid_dir, args.collection, args.induction, args.EPOCHS,
            args.BATCHSIZE, args.MODEL_NAME)


if __name__ == "__main__":
    arguments = parse_arguments()
    main(*arguments)
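
# Example invocation (hypothetical data directories and model name):
#   python train.py /data/infill/train /data/infill/valid --collection \
#       --epochs 20 --batchsize 8 --model_name collection_infill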