import datetime
import itertools
import random

import numpy as np
import torch
import yaml


def train(model, train_files, valid_files, maskpatterns, epochs, batchsize, info):
    train_losses, valid_losses = [], []
    summary = {}  # per-run statistics, dumped to YAML at the end
    now = datetime.datetime.now().strftime("%d%m%Y-%H%M%S")

    # Per-batch sizes; the last batch holds the remainder and is dropped
    # when the dataset divides evenly.
    batchsizes_train = [batchsize] * (len(train_files) // batchsize)
    batchsizes_train.append(len(train_files) % batchsize)
    batchsizes_valid = [batchsize] * (len(valid_files) // batchsize)
    batchsizes_valid.append(len(valid_files) % batchsize)
    if batchsizes_train[-1] == 0:
        batchsizes_train.pop()
    if batchsizes_valid[-1] == 0:
        batchsizes_valid.pop()

    overtrain_cntr = 0  # epochs without validation improvement (early stopping)
    for epoch in range(epochs):
        # ---- training ----
        running_loss = 0.0
        epoch_running_train_loss = 0.0

        random.shuffle(train_files)
        files_for_batches = np.split(
            np.array(train_files),
            [sum(batchsizes_train[:i]) for i in range(1, len(batchsizes_train))])

        for idx, batch_files in enumerate(files_for_batches):
            masked_tensor_batch_lst, true_tensor_batch_lst = [], []

            for batch_idx, filepath in enumerate(batch_files):
                arr = np.load(filepath).T

                # Pick a random mask pattern and shift it by a random column
                # offset, wrapping indices that run past the right edge.
                maskpattern = random.sample(maskpatterns, 1)[0]
                offset = random.randint(1, info["width"] - 1)
                offset_maskpattern = [
                    i + offset if (i + offset) < info["width"]
                    else i - info["width"] + offset
                    for i in maskpattern]
                arr_mask = np.copy(arr)
                arr_mask[:, offset_maskpattern] = 0

                masked_tensor_batch_lst.append(
                    torch.FloatTensor(arr_mask.reshape(1, *arr_mask.shape)))
                true_tensor_batch_lst.append(
                    torch.FloatTensor(arr.reshape(1, *arr.shape)))
            masked_tensor_batch = torch.stack(masked_tensor_batch_lst)
            del masked_tensor_batch_lst
            true_tensor_batch = torch.stack(true_tensor_batch_lst)
            del true_tensor_batch_lst

            masked_tensor_batch = masked_tensor_batch.to(info["DEVICE"])
            true_tensor_batch = true_tensor_batch.to(info["DEVICE"])
            info["optimizer"].zero_grad()
            outputs = model(masked_tensor_batch)
            loss = info["criterion"](outputs, true_tensor_batch, masked_tensor_batch)[5]
            loss.backward()
            info["optimizer"].step()

            del masked_tensor_batch
            del true_tensor_batch

            running_loss += loss.item()
            epoch_running_train_loss += loss.item()
            if (idx + 1) % 5 == 0:
                print('[{}, {:2.2%}] loss: {:.2f}'.format(
                    epoch + 1, (idx * batchsize) / float(len(train_files)),
                    running_loss / 5))
                running_loss = 0.0

        train_losses.append(epoch_running_train_loss / len(files_for_batches))
        # ---- validation ----
        running_loss = 0.0
        files_for_batches = np.split(
            np.array(valid_files),
            [sum(batchsizes_valid[:i]) for i in range(1, len(batchsizes_valid))])
        maskpattern_pool = itertools.cycle(maskpatterns)

        for batch_files in files_for_batches:
            masked_tensor_batch_lst, true_tensor_batch_lst = [], []

            for batch_idx, filepath in enumerate(batch_files):
                arr = np.load(filepath).T

                maskpattern = next(maskpattern_pool)
                arr_mask = np.copy(arr)
                arr_mask[:, maskpattern] = 0

                masked_tensor_batch_lst.append(
                    torch.FloatTensor(arr_mask.reshape(1, *arr_mask.shape)))
                true_tensor_batch_lst.append(
                    torch.FloatTensor(arr.reshape(1, *arr.shape)))

            masked_tensor_batch = torch.stack(masked_tensor_batch_lst)
            del masked_tensor_batch_lst
            true_tensor_batch = torch.stack(true_tensor_batch_lst)
            del true_tensor_batch_lst
            masked_tensor_batch = masked_tensor_batch.to(info["DEVICE"])
            true_tensor_batch = true_tensor_batch.to(info["DEVICE"])

            outputs = model(masked_tensor_batch)
            loss = info["criterion"](outputs, true_tensor_batch, masked_tensor_batch)[5]

            del masked_tensor_batch
            del true_tensor_batch

            running_loss += loss.item()
        valid_losses.append(running_loss / len(files_for_batches))
        print("Validation loss: {:.2f}".format(running_loss / len(files_for_batches)))

        summary['train losses'] = train_losses
        summary['valid losses'] = valid_losses
        # On the first epoch, save an initial checkpoint and record the baseline
        # validation loss (the epoch guard and summary baseline are assumed).
        if epoch == 0:
            torch.save(model.module.state_dict(), info["model_name"] + '.pth')
            old_valid_loss = valid_losses[0]
            summary['best epoch'] = epoch
            summary['best valid loss'] = valid_losses[0]

        # Overwrite the checkpoint whenever the validation loss improves.
        if (valid_losses[-1] - old_valid_loss) < 0:
            torch.save(model.module.state_dict(), info["model_name"] + '.pth')
            old_valid_loss = valid_losses[-1]
            summary['best epoch'] = epoch
            summary['best valid loss'] = valid_losses[-1]
            overtrain_cntr = 0
        else:
            overtrain_cntr += 1  # assumed: count epochs without improvement
        # Early stopping (assumed): stop after more than five consecutive
        # epochs without a validation improvement.
        if overtrain_cntr > 5:
            break

    with open('training_summary_{}.yaml'.format(now), 'w') as f:
        yaml.dump(summary, f)
    print("best valid loss: {} (at epoch {})".format(
        summary['best valid loss'], summary['best epoch']))
    print("train losses: {}\n".format(train_losses))
    print("valid losses: {}\n".format(valid_losses))
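
# A minimal usage sketch for train() above. The toy model, criterion, data
# globs, and hyperparameters are placeholders chosen only to satisfy the
# interface the function expects: an info dict with "DEVICE", "width",
# "optimizer", "criterion", and "model_name" keys, a criterion whose return
# value is indexable at [5], and a model wrapped so that model.module exists.
# None of these specific values come from the original module.

import glob
import torch

# Placeholder model: any nn.Module works, but it is wrapped in DataParallel so
# that the model.module attribute used by torch.save() exists.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.DataParallel(
    torch.nn.Conv2d(1, 1, kernel_size=3, padding=1)).to(device)

def toy_criterion(outputs, target, masked_input):
    # Placeholder criterion: the training loop reads element [5] of whatever
    # the criterion returns, so return a 6-element tuple ending in the loss.
    return (None,) * 5 + (torch.nn.functional.mse_loss(outputs, target),)

info = {
    "DEVICE": device,
    "width": 128,               # number of columns in each (transposed) .npy array
    "optimizer": torch.optim.Adam(model.parameters(), lr=1e-3),
    "criterion": toy_criterion,
    "model_name": "masked_autoencoder",   # checkpoint written to <model_name>.pth
}

train(model,
      train_files=glob.glob("data/train/*.npy"),
      valid_files=glob.glob("data/valid/*.npy"),
      maskpatterns=[[0, 1, 2], [10, 11, 12]],  # column indices to zero out
      epochs=50,
      batchsize=8,
      info=info)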