predict.py
Go to the documentation of this file.
1 """
2 Plot images that have been infilled by trained models.
3 """
4 
5 
6 import os, sys, time, argparse, itertools
7 import numpy as np
8 from matplotlib import pyplot as plt
9 from mpl_toolkits.axes_grid1 import make_axes_locatable
10 from matplotlib.lines import Line2D
11 
12 import torch
13 
14 from infill_loss import InfillLossInduction, InfillLossCollection
15 from model import UnetInduction, UnetCollection
16 # from loss_dense_infill_collection import DenseInfillLoss
17 
18 
def predict(model, test_dir, N, trace, info):
    """Run a trained infill model on masked test images and plot the results.

    Up to ``N`` ``.npy`` files are read from ``test_dir`` (in ``os.listdir``
    order, which is not guaranteed to be sorted). Each array is transposed,
    and the columns listed in the next mask pattern are zeroed to simulate
    dead channels; patterns are taken round-robin from
    ``info["maskpatterns"]``. For each sample the model's inference time and
    loss are printed. The masked input is then plotted with the network
    output overlaid on the dead channels.

    Args:
        model: torch module called on a 4-D ``(1, 1, rows, cols)`` tensor;
            element ``[0, 0]`` of its output is plotted.
        test_dir: directory containing ``.npy`` images.
        N: maximum number of files to run inference on (``N <= 0`` runs none).
        trace: if true, also plot true vs. predicted ADC against time tick
            for every dead channel.
        info: dict with keys:
            ``"DEVICE"``: torch device the tensors are moved to.
            ``"criterion"``: called as ``criterion(outputs, true, masked)``;
            element 5 of its result is printed as the loss.
            ``"maskpatterns"``: list of lists of column indices to zero.
    """
    device = info["DEVICE"]
    samples = _load_masked_samples(test_dir, N, info["maskpatterns"])

    plt.rc('font', family='serif')
    model.eval()
    with torch.no_grad():
        for filename, masked, true in samples:
            print(filename)
            # Add a batch dimension, (1, H, W) -> (1, 1, H, W). Tensor.to()
            # returns a new tensor, so the result must be reassigned.
            masked_tensor = masked.unsqueeze(0).to(device)
            true_tensor = true.unsqueeze(0).to(device)

            start = time.perf_counter()
            outputs = model(masked_tensor)
            end = time.perf_counter()
            print("Inference time:{:.1f}s".format(end - start))

            # NOTE(review): element 5 is presumably the total loss — confirm
            # against the criterion's return signature.
            loss = info["criterion"](outputs, true_tensor, masked_tensor)[5]
            print("Loss: {}".format(loss))

            # .cpu() is needed before .numpy() when running on an accelerator.
            # Transposing undoes the .T applied at load time, so rows are channels.
            img_pred = outputs.detach().cpu().numpy()[0, 0, :, :].T
            img_true = true_tensor.detach().cpu().numpy()[0, 0, :, :].T
            img_masked = masked_tensor.detach().cpu().numpy()[0, 0, :, :].T

            # A channel counts as dead if its whole row is zero in the masked
            # input. This also catches rows that were all-zero before masking.
            dead_ch = [ch for ch, row in enumerate(img_masked) if np.all(row == 0)]
            print("Dead channels: {}".format(dead_ch))

            # Boolean mask over dead rows; an int-typed index keeps an empty
            # dead_ch list valid (np.array([]) would be float and raise).
            dead_mask = np.zeros(img_masked.shape, dtype=bool)
            dead_mask[np.asarray(dead_ch, dtype=np.intp), :] = True
            # Show live channels from the input and dead channels from the network.
            img_live = np.ma.masked_array(img_masked, dead_mask)
            img_infill = np.ma.masked_array(img_pred, ~dead_mask)

            fig, ax = plt.subplots()
            fig.set_size_inches(16, 10)
            ax.imshow(img_live.T, aspect='auto', cmap='coolwarm', vmin=-30, vmax=30, interpolation='None')
            ax.imshow(img_infill.T, aspect='auto', cmap='PRGn', vmin=-30, vmax=30, interpolation='None')
            plt.show()

            if trace:
                _plot_channel_traces(img_true, img_pred, dead_ch)


def _load_masked_samples(test_dir, N, maskpatterns):
    """Return up to ``N`` ``(filename, masked, true)`` samples from ``test_dir``.

    Each tensor has shape ``(1, *arr.shape)``, where ``arr`` is the transposed
    ``.npy`` array. ``masked`` has the columns of the next mask pattern
    (cycled) set to zero.
    """
    maskpattern_pool = itertools.cycle(maskpatterns)
    samples = []
    for filename in os.listdir(test_dir):
        # Check before loading so that N <= 0 loads nothing.
        if len(samples) >= N:
            break
        if not filename.endswith(".npy"):
            continue

        arr = np.load(os.path.join(test_dir, filename)).T
        arr_mask = np.copy(arr)
        arr_mask[:, next(maskpattern_pool)] = 0

        samples.append((
            filename,
            torch.FloatTensor(arr_mask.reshape(1, *arr_mask.shape)),
            torch.FloatTensor(arr.reshape(1, *arr.shape)),
        ))
    return samples


def _plot_channel_traces(img_true, img_pred, dead_ch):
    """Plot true and network ADC against time tick for each channel in ``dead_ch``."""
    # One bin per tick, sized from the data rather than a fixed 6000 ticks.
    n_ticks = img_true.shape[1]
    tick = np.arange(1, n_ticks + 1)

    for ch in dead_ch:
        print("Channel {}".format(ch))
        plt.hist(tick, bins=len(tick), weights=img_true[ch, :], histtype='step', label="True", linewidth=0.7)
        plt.hist(tick, bins=len(tick), weights=img_pred[ch, :], histtype='step', label="Network", linewidth=0.7)
        plt.xlim(1, n_ticks + 1)
        plt.xlabel("time tick", fontsize=20)
        plt.ylabel("ADC", fontsize=20)
        plt.title("Channel {}".format(ch))

        ax = plt.gca()
        ax.tick_params(axis='both', which='major', labelsize=16)
        ax.tick_params(axis='both', which='minor', labelsize=16)
        # Enlarge the y-axis offset/scale text (e.g. "1e3") to match the labels.
        ax.yaxis.get_offset_text().set_fontsize(20)
        # Step histograms produce patch (box) legend handles; replace them
        # with plain line handles in the same edge colour.
        handles, labels = ax.get_legend_handles_labels()
        new_handles = [Line2D([], [], c=h.get_edgecolor()) for h in handles]
        plt.legend(handles=new_handles, labels=labels, prop={'size': 20})
        plt.show()
104 
105 
106 
def main(weights, test_dir, collection, induction, n, traces):
    """Build the selected infill model, load its weights and run ``predict``.

    Args:
        weights: path to a saved state-dict (``.pth``) file.
        test_dir: directory of ``.npy`` test images passed to ``predict``.
        collection: select ``UnetCollection``, ``InfillLossCollection`` and the
            collection mask patterns. Takes precedence if both flags are set.
        induction: select ``UnetInduction``, ``InfillLossInduction`` and the
            induction mask patterns.
        n: maximum number of files to run inference on.
        traces: forwarded to ``predict`` as ``trace``.

    Raises:
        ValueError: if neither ``collection`` nor ``induction`` is true.
    """
    DEVICE = torch.device("cpu")

    if collection:
        model = UnetCollection()
        criterion = InfillLossCollection().to(device=DEVICE)
        # Column indices zeroed per sample; used round-robin by predict().
        maskpatterns = [[55, 66, 78, 81, 89, 238, 370],
                        [217, 219, 221, 223, 225, 227, 250, 251, 252, 361, 363, 365, 367, 369, 371],
                        [20, 95, 134, 147, 196, 351],
                        [2, 3, 25, 27, 29, 31, 33, 35, 289, 409, 411, 413, 415, 417, 419, 456],
                        [4, 13, 424, 436]]
    elif induction:
        model = UnetInduction()
        criterion = InfillLossInduction().to(device=DEVICE)
        # TODO: change these patterns to the ROP ones from decodeDigits.
        maskpatterns = [[1, 2, 4, 94, 200, 202, 204, 206, 208, 325, 400, 401, 442, 447, 453, 455,
                         456, 472, 477, 571, 573],
                        [0, 1, 76, 191, 193, 195, 197, 199, 400, 734, 739, 746],
                        [114, 273, 401],
                        [181, 183, 301, 303, 701, 703, 781, 783],
                        [5, 151, 201, 241, 243, 257, 280, 303],
                        [212],
                        [0, 1, 238, 400, 648, 661],
                        [0, 21, 23, 341, 343, 781, 783],
                        [457, 560, 667, 784],
                        [163, 230, 417, 419, 423, 429, 477, 629, 639],
                        [1, 201, 281, 563]]
    else:
        # Without this, `model` would be unbound below (UnboundLocalError).
        raise ValueError("Either collection or induction must be selected")

    model = model.to(DEVICE)

    # map_location lets weights saved on a GPU be loaded on DEVICE.
    # NOTE: torch.load unpickles the file; only load weights from trusted sources.
    pretrained_dict = torch.load(weights, map_location=DEVICE)
    # Models trained under DataParallel save parameters as "module.<name>".
    # Strip only that leading prefix; a blanket str.replace would also mangle
    # names such as "x.submodule.w".
    prefix = "module."
    pretrained_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in pretrained_dict.items()
    }
    model.load_state_dict(pretrained_dict)

    info = {
        "DEVICE": DEVICE,
        "criterion": criterion,
        "maskpatterns": maskpatterns,
    }

    predict(model, test_dir, n, traces, info)
150 
151 
def parse_arguments(argv=None):
    """Parse the command line for this script.

    Args:
        argv: argument list to parse; ``None`` (the default) uses ``sys.argv[1:]``.

    Returns:
        Tuple ``(weights, test_dir, collection, induction, N, traces)`` in the
        positional order expected by ``main``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("weights", help="Takes a .pth file")
    parser.add_argument("test_dir", help="Directory of .npy files to run inference on")

    # Exactly one detector-plane model must be chosen.
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--collection", action="store_true", help="Use the collection model")
    group.add_argument("--induction", action="store_true", help="Use the induction model")

    # No nargs="?": with it, a bare "-n" silently set N to None, and predict()
    # then crashed comparing an int to None.
    parser.add_argument("-n", type=int, default=5, dest="N",
                        help="Number of files to do inference on")
    parser.add_argument("-t", "--traces", action="store_true",
                        help="Plot true and predicted ADC by time tick for each dead channel that has been infilled")

    args = parser.parse_args(argv)

    return (args.weights, args.test_dir, args.collection, args.induction, args.N, args.traces)
170 
171 
if __name__ == "__main__":
    # Script entry point: parse the command line and forward the resulting
    # (weights, test_dir, collection, induction, N, traces) tuple to main().
    main(*parse_arguments())
175 
176 
177 
178 
179 
180 
181 
182 
def parse_arguments()
Definition: predict.py:152
static bool format(QChar::Decomposition tag, QString &str, int index, int len)
Definition: qstring.cpp:11496
Definition: model.py:1
auto enumerate(Iterables &&...iterables)
Range-for loop helper tracking the number of iteration.
Definition: enumerate.h:69
def predict(model, test_dir, N, trace, info)
Definition: predict.py:19
def main(weights, test_dir, collection, induction, n, traces)
Definition: predict.py:107