# NOTE(review): the original docstring was garbled ("loaded direclty into 11")
# and the imports below the docstring were lost in extraction; they are
# reconstructed from the names the code actually uses — confirm.
"""Trace trained infill model to produce a TorchScript.

Resulting .pt can then be loaded directly into the consuming application.
"""
import argparse
import os

import numpy as np
import torch

from model import UnetInduction, UnetCollection
def main(input_file, out_name, example_dir, collection):
    """Trace a trained infill model and save it as TorchScript.

    Args:
        input_file: path to the .pth state-dict checkpoint to serialise.
        out_name: output basename; ".pt" is appended to form the saved file.
        example_dir: directory holding .npy example inputs used for tracing.
        collection: True to build a UnetCollection, False for UnetInduction.

    Raises:
        FileNotFoundError: if example_dir contains no .npy files.
    """
    # NOTE(review): the model-construction lines were lost in the source;
    # reconstructed from the --collection/--induction CLI flags — confirm
    # the constructors take no arguments.
    model = UnetCollection() if collection else UnetInduction()
    DEVICE = torch.device("cpu")
    model = model.to(DEVICE)

    # Checkpoints saved from a DataParallel wrapper prefix every key with
    # "module."; strip it so the keys match the bare model.
    pretrained_dict = torch.load(input_file, map_location="cpu")
    pretrained_dict = {key.replace("module.", ""): value
                       for key, value in pretrained_dict.items()}
    model.load_state_dict(pretrained_dict)

    example_img = None
    for filename in os.listdir(example_dir):
        if filename.endswith(".npy"):
            arr = np.load(os.path.join(example_dir, filename)).T
            # Zero the masked columns — presumably the channels the infill
            # model is trained to reconstruct; verify against training code.
            maskpattern = [114, 273, 401]
            arr[:, maskpattern] = 0
            # Add batch dims; only the LAST .npy file survives the loop and
            # is used as the tracing example (original behavior, kept).
            example_img = torch.FloatTensor(arr.reshape(1, *arr.shape))
            example_img = torch.stack([example_img])

    # BUG FIX: previously an empty/.npy-free directory raised a bare
    # NameError at trace time; fail with a clear message instead.
    if example_img is None:
        raise FileNotFoundError(
            "no .npy example files found in {}".format(example_dir))

    # BUG FIX: put the model in eval mode so dropout/batch-norm are traced
    # with inference behavior, not training behavior.
    model.eval()
    traced_model = torch.jit.trace(model, example_img)
    traced_model.save(out_name + ".pt")
    print("TorchScript summary:\n{}".format(traced_model))
def parse_arguments():
    """Parse command-line arguments for the serialisation script.

    Returns:
        Tuple of (input_file, output_name, example_dir, collection_flag).
    """
    # NOTE(review): the enclosing `def` line was lost in the garbled source;
    # the function name is reconstructed — confirm against the caller.
    parser = argparse.ArgumentParser()

    parser.add_argument("input_file", help=".pth file to be serialised")
    parser.add_argument("output_name")
    parser.add_argument(
        "example_dir",
        help="Directory containing input data for the model being serialised")

    # Exactly one architecture flag must be supplied.
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--collection", action='store_true')
    group.add_argument("--induction", action='store_true')

    args = parser.parse_args()
    # args.induction is implied by `not args.collection` (the group is
    # required and mutually exclusive), so only the collection flag is
    # returned.
    return (args.input_file, args.output_name, args.example_dir, args.collection)
if __name__ == "__main__":
    # NOTE(review): the original call site was garbled (followed by a stray
    # C++ fragment and a truncated duplicate of main's signature, both
    # removed). main's four parameters exactly match the 4-tuple the
    # argument parser returns — confirm the parser function's real name.
    main(*parse_arguments())