make_tscript.py
Go to the documentation of this file.
1 """
2 Trace a trained infill model to produce a TorchScript module. The resulting .pt file can then be loaded directly into
3 C++.
4 """
5 
6 import os, argparse
7 
8 import torch
9 import numpy as np
10 
11 from model import UnetInduction, UnetCollection
12 
# Channel indices zeroed in the example input before tracing (example values
# carried over from the original script).
_EXAMPLE_MASK_PATTERN = (114, 273, 401)


def _load_model(input_file, collection):
    """Build the requested U-Net variant on CPU in eval mode and load its weights.

    Args:
        input_file: Path to the saved state dict (.pth).
        collection: True selects ``UnetCollection``; False selects ``UnetInduction``.

    Returns:
        The model with the weights from ``input_file`` loaded.
    """
    model = UnetCollection() if collection else UnetInduction()
    model = model.to(torch.device("cpu"))
    model.eval()

    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from trusted sources.
    state_dict = torch.load(input_file, map_location="cpu")

    # Weights saved from a DataParallel-wrapped model have every key prefixed
    # with "module.". Strip only that leading prefix. A blanket
    # str.replace would also corrupt keys such as "submodule.weight".
    prefix = "module."
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    model.load_state_dict(state_dict)
    return model


def _load_example_input(example_dir, mask_pattern):
    """Build the tracing input from the first .npy file (sorted by name) in example_dir.

    The array is transposed, the columns listed in ``mask_pattern`` are
    zeroed, and two leading singleton dimensions are added, giving a float32
    tensor of shape ``(1, 1, *arr.T.shape)``.
    NOTE(review): the .npy files are presumably 2-D images; confirm.

    Raises:
        FileNotFoundError: if ``example_dir`` contains no .npy files.
    """
    # Sort so that the same example is chosen on every run; os.listdir order
    # is arbitrary.
    npy_files = sorted(name for name in os.listdir(example_dir) if name.endswith(".npy"))
    if not npy_files:
        raise FileNotFoundError(f"No .npy example files found in {example_dir!r}")

    arr = np.load(os.path.join(example_dir, npy_files[0])).T
    # Convert to a list so numpy applies fancy indexing over the columns.
    arr[:, list(mask_pattern)] = 0
    return torch.tensor(arr, dtype=torch.float32).reshape(1, 1, *arr.shape)


def main(input_file, out_name, example_dir, collection, *, mask_pattern=_EXAMPLE_MASK_PATTERN):
    """Trace a trained model with an example input and save it as ``<out_name>.pt``.

    Args:
        input_file: Path to the saved state dict (.pth) to serialise.
        out_name: Output path without extension; ".pt" is appended.
        example_dir: Directory containing .npy input examples used for tracing.
        collection: True traces ``UnetCollection``; False traces ``UnetInduction``.
        mask_pattern: Column indices of the example input to zero before tracing.

    Raises:
        FileNotFoundError: if ``example_dir`` contains no .npy files.
    """
    model = _load_model(input_file, collection)
    example_img = _load_example_input(example_dir, mask_pattern)

    with torch.no_grad():
        traced_model = torch.jit.trace(model, example_img)
    traced_model.save(out_name + ".pt")
    print(f"TorchScript summary:\n{traced_model}")
41 
42 
def parse_arguments(argv=None):
    """Parse the command-line arguments for this script.

    Args:
        argv: Optional list of argument strings to parse. When None, argparse
            falls back to ``sys.argv[1:]``, as in the original script.

    Returns:
        Tuple ``(input_file, output_name, example_dir, collection)``, in the
        positional order expected by ``main``. ``collection`` is True when
        ``--collection`` was given and False when ``--induction`` was given.
        Exactly one of the two flags is required.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("input_file", help=".pth file to be serialised")
    parser.add_argument("output_name", help="Output path for the TorchScript file (\".pt\" is appended)")
    parser.add_argument("example_dir", help="Directory containing input data for the model being serialised")

    # Exactly one model variant must be chosen. Only --collection is read
    # afterwards; --induction exists so that the choice is explicit.
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--collection", action="store_true", help="Serialise a UnetCollection model")
    group.add_argument("--induction", action="store_true", help="Serialise a UnetInduction model")

    args = parser.parse_args(argv)

    return (args.input_file, args.output_name, args.example_dir, args.collection)
57 
58 
if __name__ == "__main__":
    # Script entry point: parse the CLI and pass the resulting tuple to
    # main() as positional arguments.
    main(*parse_arguments())
static bool format(QChar::Decomposition tag, QString &str, int index, int len)
Definition: qstring.cpp:11496
def parse_arguments()
Definition: make_tscript.py:43
def main(input_file, out_name, example_dir, collection)
Definition: make_tscript.py:13