"""
Train the infill networks.
"""
import os, datetime, random, argparse, itertools

import numpy as np
import yaml
import torch
import torch.nn as nn

from infill_loss import InfillLossInduction, InfillLossCollection
from model import UnetInduction, UnetCollection
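# infill_loss and model are local modules of this repository; they provide the
# plane-specific infill losses (InfillLossInduction, InfillLossCollection) and
# U-Net architectures (UnetInduction, UnetCollection) used in main() below.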


def train(model, train_files, valid_files, maskpatterns, epochs, batchsize, info):
    train_losses, valid_losses = [], []
    summary = {}  # per-epoch losses and best-epoch bookkeeping, dumped to YAML at the end

    now = datetime.datetime.now().strftime("%d%m%Y-%H%M%S")

    # Split the file lists into full batches of size batchsize plus one smaller
    # remainder batch, dropping the remainder if it is empty.
    batchsizes_train = [batchsize]*int(len(train_files)/batchsize)
    batchsizes_train.append(len(train_files) % batchsize)
    batchsizes_valid = [batchsize]*int(len(valid_files)/batchsize)
    batchsizes_valid.append(len(valid_files) % batchsize)
    if batchsizes_train[-1] == 0:
        batchsizes_train.pop()
    if batchsizes_valid[-1] == 0:
        batchsizes_valid.pop()
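    # Worked example of the batching bookkeeping above (illustrative numbers):
    # with 10 training files and batchsize 4, batchsizes_train is [4, 4, 2], the
    # cumulative split points passed to np.split below are [4, 8], and the epoch
    # then iterates over batches of 4, 4 and 2 file paths.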

    for epoch in range(epochs):
        # Training
        running_loss = 0.0
        epoch_running_train_loss = 0.0

        random.shuffle(train_files)
        files_for_batches = np.split(
            np.array(train_files),
            [ sum(batchsizes_train[:i]) for i in range(1, len(batchsizes_train)) ])

        for idx, batch_files in enumerate(files_for_batches):
            masked_tensor_batch_lst, true_tensor_batch_lst = [], []

            for batch_idx, filepath in enumerate(batch_files):
                arr = np.load(filepath).T

                maskpattern = random.sample(maskpatterns, 1)[0]
                offset = random.randint(1, info["width"] - 1)
                offset_maskpattern = [ i + offset if (i + offset) < info["width"]
                                       else i - info["width"] + offset
                                       for i in maskpattern ]
                arr_mask = np.copy(arr)
                arr_mask[:, offset_maskpattern] = 0
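                # Illustration of the wrap-around offset above (made-up numbers):
                # with info["width"] == 480, maskpattern [2, 478] and offset 3,
                # the masked channels become [5, 1], so the dead-channel pattern
                # is shifted cyclically along the wire axis as augmentation.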

                masked_tensor_batch_lst.append(torch.FloatTensor(arr_mask.reshape(1, *arr_mask.shape)))
                true_tensor_batch_lst.append(torch.FloatTensor(arr.reshape(1, *arr.shape)))

            masked_tensor_batch = torch.stack(masked_tensor_batch_lst)
            del masked_tensor_batch_lst
            true_tensor_batch = torch.stack(true_tensor_batch_lst)
            del true_tensor_batch_lst

            masked_tensor_batch = masked_tensor_batch.to(info["DEVICE"])
            true_tensor_batch = true_tensor_batch.to(info["DEVICE"])

            info["optimizer"].zero_grad()
            outputs = model(masked_tensor_batch)
            loss = info["criterion"](outputs, true_tensor_batch, masked_tensor_batch)[5]
            loss.backward()
            info["optimizer"].step()
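            # The infill criterion is assumed to return several loss components;
            # index 5 is the component optimised here, matching the selection
            # made in the validation loop below.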

            del masked_tensor_batch
            del true_tensor_batch

            running_loss += loss.item()
            epoch_running_train_loss += loss.item()
            if (idx + 1) % 5 == 0:
                print('[{}, {:2.2%}] loss: {:.2f}'.format(
                    epoch + 1, (idx*batchsize)/float(len(train_files)), running_loss/5))
                running_loss = 0.0

        train_losses.append(epoch_running_train_loss/len(files_for_batches))

        # Validation
        running_loss = 0.0
        files_for_batches = np.split(
            np.array(valid_files),
            [ sum(batchsizes_valid[:i]) for i in range(1, len(batchsizes_valid)) ])
        maskpattern_pool = itertools.cycle(maskpatterns)
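        # itertools.cycle replays the fixed mask patterns in order across the
        # validation files, so validation uses a fixed, repeating masking rather
        # than the random pattern choice and offset used in training.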

        for batch_files in files_for_batches:
            masked_tensor_batch_lst, true_tensor_batch_lst = [], []

            for batch_idx, filepath in enumerate(batch_files):
                arr = np.load(filepath).T

                maskpattern = next(maskpattern_pool)
                arr_mask = np.copy(arr)
                arr_mask[:, maskpattern] = 0

                masked_tensor_batch_lst.append(torch.FloatTensor(arr_mask.reshape(1, *arr_mask.shape)))
                true_tensor_batch_lst.append(torch.FloatTensor(arr.reshape(1, *arr.shape)))

            masked_tensor_batch = torch.stack(masked_tensor_batch_lst)
            del masked_tensor_batch_lst
            true_tensor_batch = torch.stack(true_tensor_batch_lst)
            del true_tensor_batch_lst

            masked_tensor_batch = masked_tensor_batch.to(info["DEVICE"])
            true_tensor_batch = true_tensor_batch.to(info["DEVICE"])

            outputs = model(masked_tensor_batch)
            loss = info["criterion"](outputs, true_tensor_batch, masked_tensor_batch)[5]

            del masked_tensor_batch
            del true_tensor_batch

            running_loss += loss.item()

        valid_losses.append(running_loss/len(files_for_batches))
        print("Validation loss: {:.2f}".format(running_loss/len(files_for_batches)))

        summary['train losses'] = train_losses
        summary['valid losses'] = valid_losses

        # Checkpoint on the first epoch and whenever the validation loss improves;
        # count epochs without improvement for early stopping.
        if epoch == 0:
            torch.save(model.module.state_dict(), info["model_name"] + '.pth')
            old_valid_loss = valid_losses[0]
            summary['best epoch'] = epoch
            summary['best valid loss'] = valid_losses[0]
            overtrain_cntr = 0

        if (valid_losses[-1] - old_valid_loss) < 0:
            torch.save(model.module.state_dict(), info["model_name"] + '.pth')
            old_valid_loss = valid_losses[-1]
            summary['best epoch'] = epoch
            summary['best valid loss'] = valid_losses[-1]
            overtrain_cntr = 0
        else:
            overtrain_cntr += 1

        if overtrain_cntr > 5:
            break
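        # Saving model.module.state_dict() (rather than model.state_dict()) stores
        # the weights without the nn.DataParallel wrapper, so the checkpoint can
        # be loaded into a plain, single-GPU instance of the network.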

    # Write the training summary once the epoch loop ends (or stops early).
    with open('training_summary_{}.yaml'.format(now), 'w') as f:
        yaml.dump(summary, f)

    print("best valid loss: {} (at epoch {})".format(
        summary['best valid loss'], summary['best epoch']))
    print("train losses: {}\n".format(train_losses))
    print("valid losses: {}\n".format(valid_losses))


def adjust_learning_rate(optimizer, epoch, lr):
    """Step-decay the learning rate: halve the base lr every 20 epochs."""
    lr = lr * (0.5 ** (epoch // 20))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
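
# For example, with a base lr of 1.0e-4 the schedule above gives 1.0e-4 for
# epochs 0-19, 5.0e-5 for epochs 20-39, 2.5e-5 for epochs 40-59, and so on.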


def main(train_dir, valid_dir, collection, induction, epochs, batchsize, model_name):
    DEVICE = torch.device("cuda:0")

    # Plane-specific setup. Only the mask-pattern rows visible in the original
    # listing are kept; the constructor calls, width values, learning rate and
    # device transfer below are reconstructed assumptions.
    if collection:
        maskpatterns = [[55, 66, 78, 81, 89, 238, 370],
                        [217, 219, 221, 223, 225, 227, 250, 251, 252, 361, 363, 365, 367, 369, 371],
                        [20, 95, 134, 147, 196, 351],
                        [2, 3, 25, 27, 29, 31, 33, 35, 289, 409, 411, 413, 415, 417, 419, 456]]
        model = UnetCollection()
        criterion = InfillLossCollection()
        width = 480  # assumed collection-plane crop width
    elif induction:
        maskpatterns = [[1, 2, 4, 94, 200, 202, 204, 206, 208, 325, 400, 401, 442, 447, 453, 455,
                         456, 472, 477, 571, 573],
                        [0, 1, 76, 191, 193, 195, 197, 199, 400, 734, 739, 746],
                        [181, 183, 301, 303, 701, 703, 781, 783],
                        [5, 151, 201, 241, 243, 257, 280, 303],
                        [0, 1, 238, 400, 648, 661],
                        [0, 21, 23, 341, 343, 781, 783],
                        [457, 560, 667, 784],
                        [163, 230, 417, 419, 423, 429, 477, 629, 639]]
        model = UnetInduction()
        criterion = InfillLossInduction()
        width = 800  # assumed induction-plane crop width

    model = nn.DataParallel(model)
    model = model.to(DEVICE)

    lr = 1.0e-4  # assumed starting learning rate
    weight_decay = 1.0e-4
    optimizer = torch.optim.RMSprop(model.parameters(), lr=lr, weight_decay=weight_decay)

    train_files = [ os.path.join(train_dir, filename) for filename in os.listdir(train_dir)
                    if filename.endswith(".npy") ]
    valid_files = [ os.path.join(valid_dir, filename) for filename in os.listdir(valid_dir)
                    if filename.endswith(".npy") ]
200 "criterion" : criterion,
201 "optimizer" : optimizer,
203 "model_name" : model_name,
207 train(model, train_files, valid_files, maskpatterns, epochs, batchsize, info)


def parse_arguments():  # the wrapping function's name is assumed; it is not visible in the listing
    parser = argparse.ArgumentParser()

    parser.add_argument("train_dir")
    parser.add_argument("valid_dir")

    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--collection", action='store_true')
    group.add_argument("--induction", action='store_true')

    parser.add_argument("-e", "--epochs", nargs='?', type=int, default=10, action='store',
                        dest='EPOCHS')
    parser.add_argument("-b", "--batchsize", nargs='?', type=int, default=12, action='store',
                        dest='BATCHSIZE')
    parser.add_argument("--model_name", nargs='?', type=str, action='store', dest='MODEL_NAME',
                        default="{}".format(datetime.datetime.now().strftime("_%d%m%Y-%H%M%S")))

    args = parser.parse_args()

    return (args.train_dir, args.valid_dir, args.collection, args.induction, args.EPOCHS,
            args.BATCHSIZE, args.MODEL_NAME)
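
# Example invocation (illustrative only; the script filename and data paths are
# assumptions, not taken from the listing):
#   python train.py /path/to/train_crops /path/to/valid_crops --collection -e 50 -b 8 --model_name unet_collection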


if __name__ == "__main__":
    arguments = parse_arguments()
    main(*arguments)