make_tscript Namespace Reference

Functions

def main (input_file, out_name, example_dir, collection)
 
def parse_arguments ()
 

Variables

 arguments = parse_arguments()
 

Detailed Description

Traces a trained infill model to produce a TorchScript module. The resulting .pt file can then be loaded directly into
C++.

Function Documentation

def make_tscript.main (input_file, out_name, example_dir, collection)

Definition at line 13 of file make_tscript.py.

def main(input_file, out_name, example_dir, collection):
    # Select the collection or induction network.
    if collection:
        model = UnetCollection()
    else:
        model = UnetInduction()

    DEVICE = torch.device("cpu")
    model = model.to(DEVICE)
    model.eval()
    pretrained_dict = torch.load(input_file, map_location="cpu")
    # Strip "module." from parameter names (the prefix nn.DataParallel typically adds).
    pretrained_dict = {key.replace("module.", ""): value for key, value in pretrained_dict.items()}
    model.load_state_dict(pretrained_dict)

    # Build one example input from the first .npy file found.
    example_img = None
    for filename in os.listdir(example_dir):
        if filename.endswith(".npy"):
            arr = np.load(os.path.join(example_dir, filename)).T
            maskpattern = [114, 273, 401]  # Example maskpattern
            arr[:, maskpattern] = 0  # Zero the masked columns
            example_img = torch.FloatTensor(arr.reshape(1, *arr.shape))  # Add a channel dimension
            example_img = torch.stack([example_img])  # Add a batch dimension
            break
    # Fail early rather than hit an undefined example_img when tracing.
    if example_img is None:
        raise FileNotFoundError("No .npy file found in {}".format(example_dir))

    with torch.no_grad():
        traced_model = torch.jit.trace(model, example_img)
    traced_model.save(out_name + ".pt")
    print("TorchScript summary:\n{}".format(traced_model))

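The listing removes "module." with str.replace, which deletes that substring wherever it occurs in a key, not only at the start. A minimal sketch of a prefix-only alternative follows, using a plain dict in place of a real checkpoint. The helper name strip_module_prefix and the sample keys are hypothetical and not part of make_tscript.py.

```python
def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading prefix removed from each key.

    Hypothetical helper, not part of make_tscript.py. Unlike str.replace,
    it only changes keys that start with the prefix.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Plain dict standing in for a loaded checkpoint; the keys are made up.
checkpoint = {
    "module.conv1.weight": 1,
    "module.conv1.bias": 2,
    "head.module.weight": 3,
}

cleaned = strip_module_prefix(checkpoint)
print(sorted(cleaned))
```

For checkpoints whose keys contain "module." only as a leading prefix, both approaches give the same result; the difference only matters if a parameter name contains the substring elsewhere.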
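main takes its example input from the first .npy file that os.listdir returns, and os.listdir makes no guarantee about ordering. The sketch below shows one way to make that choice deterministic and to fail with a clear error when the directory holds no .npy file. The helper first_npy_file and the file names are hypothetical and not part of make_tscript.py.

```python
import os
import tempfile


def first_npy_file(example_dir):
    """Return the path of the first .npy file in example_dir, sorted by name.

    Hypothetical helper, not part of make_tscript.py. Raises
    FileNotFoundError when no .npy file is present.
    """
    for filename in sorted(os.listdir(example_dir)):
        if filename.endswith(".npy"):
            return os.path.join(example_dir, filename)
    raise FileNotFoundError("No .npy file found in {}".format(example_dir))


# Demonstrate with a throwaway directory; the file names are made up.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("notes.txt", "b_event.npy", "a_event.npy"):
        open(os.path.join(tmp, name), "w").close()
    chosen = os.path.basename(first_npy_file(tmp))

print(chosen)
```

Sorting by name is only one possible rule; any fixed ordering would make the traced example reproducible across runs.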
def make_tscript.parse_arguments ( )

Definition at line 43 of file make_tscript.py.

def parse_arguments():
    parser = argparse.ArgumentParser()

    parser.add_argument("input_file", help=".pth file to be serialised")
    parser.add_argument("output_name")
    parser.add_argument("example_dir", help="Directory containing input data for the model being serialised")

    # Exactly one of --collection / --induction must be given.
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--collection", action='store_true')
    group.add_argument("--induction", action='store_true')

    args = parser.parse_args()

    # Only args.collection is returned; --induction corresponds to collection being False.
    return (args.input_file, args.output_name, args.example_dir, args.collection)

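The required mutually exclusive group means exactly one of --collection or --induction must be supplied. The sketch below rebuilds an equivalent parser to show that behaviour without touching sys.argv. The function build_parser and the argument values are hypothetical and not part of make_tscript.py.

```python
import argparse


def build_parser():
    """Build a parser mirroring the options of parse_arguments.

    Hypothetical stand-in for illustration, not part of make_tscript.py.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("input_file")
    parser.add_argument("output_name")
    parser.add_argument("example_dir")
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--collection", action="store_true")
    group.add_argument("--induction", action="store_true")
    return parser


parser = build_parser()

# Passing a list to parse_args avoids reading sys.argv; the paths are made up.
args = parser.parse_args(["model.pth", "traced", "examples/", "--induction"])
print(args.collection, args.induction)

# Giving both flags makes argparse print a usage error and raise SystemExit.
try:
    parser.parse_args(["model.pth", "traced", "examples/", "--collection", "--induction"])
    rejected = False
except SystemExit:
    rejected = True
print(rejected)
```

Because the two flags are mutually exclusive and one is required, returning only args.collection, as parse_arguments does, is enough to recover which flag was given.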

Variable Documentation

make_tscript.arguments = parse_arguments()

Definition at line 60 of file make_tscript.py.