107 def main(weights, test_dir, collection, induction, n, traces):
108 DEVICE = torch.device(
"cpu")
113 maskpatterns = [[55, 66, 78, 81, 89, 238, 370],
114 [217, 219, 221, 223, 225, 227, 250, 251, 252, 361, 363, 365, 367, 369, 371],
115 [20, 95, 134, 147, 196, 351],
116 [2, 3, 25, 27, 29, 31, 33, 35, 289, 409, 411, 413, 415, 417, 419, 456],
123 maskpatterns = [[1, 2, 4, 94, 200, 202, 204, 206, 208, 325, 400, 401, 442, 447, 453, 455,
124 456, 472, 477, 571, 573],
125 [0, 1, 76, 191, 193, 195, 197, 199, 400, 734, 739, 746],
127 [181, 183, 301, 303, 701, 703, 781, 783],
128 [5, 151, 201, 241, 243, 257, 280, 303],
130 [0, 1, 238, 400, 648, 661],
131 [0, 21, 23, 341, 343, 781, 783],
132 [457, 560, 667, 784],
133 [163, 230, 417, 419, 423, 429, 477, 629, 639],
136 model = model.to(DEVICE)
139 pretrained_dict = torch.load(weights, map_location=
"cpu")
140 pretrained_dict = {key.replace(
"module.",
""): value
for key, value
in pretrained_dict.items()}
141 model.load_state_dict(pretrained_dict)
145 "criterion" : criterion,
146 "maskpatterns" : maskpatterns
149 predict(model, test_dir, n, traces, info)
def main(weights, test_dir, collection, induction, n, traces)