"""
Plot images that have been infilled by trained models.
"""
import os, sys, time, argparse, itertools

import numpy as np
import torch
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.lines import Line2D

from infill_loss import InfillLossInduction, InfillLossCollection
from model import UnetInduction, UnetCollection


def predict(model, test_dir, N, trace, info):
    test_masked_lst, test_true_lst, filenames = [], [], []
    maskpattern_pool = itertools.cycle(info["maskpatterns"])

    for filename in os.listdir(test_dir):
        if filename.endswith(".npy"):
            arr = np.load(os.path.join(test_dir, filename)).T

            # Zero out the channels in the current mask pattern to emulate dead channels.
            maskpattern = next(maskpattern_pool)
            arr_mask = np.copy(arr)
            arr_mask[:, maskpattern] = 0

            test_masked_lst.append(torch.FloatTensor(arr_mask.reshape(1, *arr_mask.shape)))
            test_true_lst.append(torch.FloatTensor(arr.reshape(1, *arr.shape)))
            filenames.append(filename)

        if len(test_masked_lst) >= N:
            break

    for idx, masked in enumerate(test_masked_lst):
        masked_tensor = torch.stack([masked])
        true_tensor = torch.stack([test_true_lst[idx]])

        masked_tensor = masked_tensor.to(info["DEVICE"])
        true_tensor = true_tensor.to(info["DEVICE"])

        start = time.time()
        outputs = model(masked_tensor)
        end = time.time()
        print("Inference time: {:.1f}s".format(end - start))

        loss = info["criterion"](outputs, true_tensor, masked_tensor)[5]
        print("Loss: {}".format(loss))

        img_pred = outputs.detach().numpy()[0, 0, :, :].T
        img_true = true_tensor.detach().numpy()[0, 0, :, :].T
        img_masked = masked_tensor.detach().numpy()[0, 0, :, :].T

        dead_ch = [idx for idx, col in enumerate(img_masked) if np.all(col == 0)]
        print("Dead channels: {}".format(dead_ch))

        mask = np.zeros_like(img_masked)
        mask[np.array(dead_ch), :] = 1
        img_masked = np.ma.masked_array(img_masked, mask)
        img_infill = np.ma.masked_array(img_pred, np.logical_not(mask))

        plt.rc('font', family='serif')
        fig, ax = plt.subplots()
        fig.set_size_inches(16, 10)
        im1 = ax.imshow(img_masked.T, aspect='auto', cmap='coolwarm',
                        vmin=-30, vmax=30, interpolation='none')
        im2 = ax.imshow(img_infill.T, aspect='auto', cmap='PRGn',
                        vmin=-30, vmax=30, interpolation='none')
        plt.show()

        if trace:
            # Plot true vs predicted ADC waveforms for each infilled dead channel.
            for ch in dead_ch:
                print("Channel {}".format(ch))
                tick_adc_true = img_true[ch, :]
                tick_adc_pred = img_pred[ch, :]
                tick = np.arange(1, 6001)

                plt.hist(tick, bins=len(tick), weights=tick_adc_true, histtype='step',
                         label="True", linewidth=0.7)
                plt.hist(tick, bins=len(tick), weights=tick_adc_pred, histtype='step',
                         label="Network", linewidth=0.7)

                plt.xlabel("time tick", fontsize=20)
                plt.ylabel("ADC", fontsize=20)
                plt.title("Channel {}".format(ch))

                ax = plt.gca()  # axes of the trace figure, not the image overlay
                ax.tick_params(axis='both', which='major', labelsize=16)
                ax.tick_params(axis='both', which='minor', labelsize=16)

                tx = ax.yaxis.get_offset_text()
                tx.set_fontsize(16)

                # Step histograms return Polygon handles; rebuild them as lines so the
                # legend swatches match the histogram colours.
                handles, labels = ax.get_legend_handles_labels()
                new_handles = [Line2D([], [], c=h.get_edgecolor()) for h in handles]
                plt.legend(handles=new_handles, labels=labels, prop={'size': 20})

                plt.show()


def main(weights, test_dir, collection, induction, n, traces):
    DEVICE = torch.device("cpu")

    # Model, loss, and mask patterns for the requested readout plane.
    # NOTE: the model and criterion are constructed with default arguments here;
    # adjust if the constructors in model.py / infill_loss.py take parameters.
    if collection:
        model = UnetCollection()
        criterion = InfillLossCollection()
        maskpatterns = [[55, 66, 78, 81, 89, 238, 370],
                        [217, 219, 221, 223, 225, 227, 250, 251, 252, 361, 363, 365, 367, 369, 371],
                        [20, 95, 134, 147, 196, 351],
                        [2, 3, 25, 27, 29, 31, 33, 35, 289, 409, 411, 413, 415, 417, 419, 456]]

    elif induction:
        model = UnetInduction()
        criterion = InfillLossInduction()
        maskpatterns = [[1, 2, 4, 94, 200, 202, 204, 206, 208, 325, 400, 401, 442, 447, 453, 455,
                         456, 472, 477, 571, 573],
                        [0, 1, 76, 191, 193, 195, 197, 199, 400, 734, 739, 746],
                        [181, 183, 301, 303, 701, 703, 781, 783],
                        [5, 151, 201, 241, 243, 257, 280, 303],
                        [0, 1, 238, 400, 648, 661],
                        [0, 21, 23, 341, 343, 781, 783],
                        [457, 560, 667, 784],
                        [163, 230, 417, 419, 423, 429, 477, 629, 639]]

    model = model.to(DEVICE)

    # Checkpoints saved from nn.DataParallel prefix keys with "module."; strip it.
    pretrained_dict = torch.load(weights, map_location="cpu")
    pretrained_dict = {key.replace("module.", ""): value
                       for key, value in pretrained_dict.items()}
    model.load_state_dict(pretrained_dict)
    model.eval()

    info = {
        "DEVICE": DEVICE,
        "criterion": criterion,
        "maskpatterns": maskpatterns
    }

    predict(model, test_dir, n, traces, info)


def parse_arguments():
    parser = argparse.ArgumentParser()

    parser.add_argument("weights", help="Takes a .pth file")
    parser.add_argument("test_dir")

    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--collection", action="store_true")
    group.add_argument("--induction", action="store_true")

    parser.add_argument("-n", nargs="?", type=int, action="store", default=5, dest="N",
                        help="Number of files to do inference on")
    parser.add_argument("-t", "--traces", action="store_true",
                        help="Plot true and predicted ADC by time tick for each dead channel that has been infilled")

    args = parser.parse_args()

    return (args.weights, args.test_dir, args.collection, args.induction, args.N, args.traces)


if __name__ == "__main__":
    main(*parse_arguments())