save_tf_proto.py
Go to the documentation of this file.
import argparse
import os

# Command-line options are parsed before keras/tensorflow are imported so the
# GPU selection below can take effect before any framework code runs.
parser = argparse.ArgumentParser(description='Save model to TF protobuf')
parser.add_argument('-m', '--model', help="Keras TF model", default='model')
parser.add_argument('-o', '--output',
                    help="TF graph file name ('.pb' is appended)",
                    default='tf_graph.proto')
parser.add_argument('-n', '--nname', help="Use node as output if name contains", default='cnn_output')
parser.add_argument('-g', '--gpu', help="Which GPU index", default='0')
args = parser.parse_args()

# Restrict the visible GPUs BEFORE importing keras/tensorflow: CUDA reads this
# variable once, at initialisation, so setting it first guarantees it applies
# no matter when the backend decides to initialise the device.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

from keras import backend as K
from keras.models import model_from_json
from keras.optimizers import SGD  # NOTE(review): not used in the visible code

from tensorflow.python.framework.graph_util import convert_variables_to_constants
import tensorflow as tf
18 
def load_model(name):
    """Rebuild a Keras model from its saved architecture and weights.

    Args:
        name: path prefix. The architecture JSON is read from
            ``name + '_architecture.json'`` and the weights from
            ``name + '_weights.h5'``.

    Returns:
        The model returned by ``model_from_json`` with its weights loaded.
    """
    architecture_path = name + '_architecture.json'
    with open(architecture_path) as json_file:
        architecture_json = json_file.read()

    net = model_from_json(architecture_json)
    net.load_weights(name + '_weights.h5')
    return net
24 
import sys

# Learning phase 0 is Keras' test/inference phase; set it before the model is
# built so the exported graph uses inference-time behaviour.
K.set_learning_phase(0)
# `m` is not referenced again; loading the model builds its ops in the Keras
# session graph that is inspected below.
m = load_model(args.model)

# A single session handle is reused for node listing and for freezing.
sess = K.get_session()
nnames = [n.name for n in sess.graph.as_graph_def().node]

# Select output nodes. A node is a candidate when its name contains '_netout'
# or the --nname substring. Candidates are grouped by their top-level name
# scope (text before the first '/'); whenever the scope changes, the last
# candidate of the previous run is recorded as an output node.
lastnode = None
currentname = ''
output_nodes = []
for i, n in enumerate(nnames):
    # split() keeps the whole name when there is no '/'; the former
    # n[:n.find('/')] sliced with -1 in that case and dropped the last char.
    basename = n.split('/', 1)[0]
    print('%d %s %s' % (i, n, basename))
    if '_netout' in n or args.nname in n:
        if lastnode is not None and basename != currentname:
            output_nodes.append(lastnode)
        currentname = basename
        lastnode = n

if lastnode is not None:
    output_nodes.append(lastnode)  # last candidate of the final run

if not output_nodes:
    print('Cannot find output node')
    sys.exit(1)  # unlike the site-provided exit(), always available
print('Output node names found: %s' % (output_nodes,))

# Replace variables with constants in the subgraph needed by output_nodes.
minimal_graph = convert_variables_to_constants(
    sess, sess.graph.as_graph_def(), output_nodes)

nnames = [n.name for n in minimal_graph.node]
print(nnames)

# Binary GraphDef written to the current directory. '.pb' is appended to
# --output, so the default name yields 'tf_graph.proto.pb'.
tf.train.write_graph(minimal_graph, '.', args.output + '.pb', as_text=False)

print('all done!')
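int open(const char *, int)
Opens a file descriptor. (Doxygen cross-reference to an unrelated C/POSIX symbol; the script calls Python's built-in open().)
auto enumerate(Iterables &&...iterables)
Range-for loop helper tracking the number of iterations. (Doxygen cross-reference to an unrelated C++ helper; the script calls Python's built-in enumerate().)
Definition: enumerate.h:69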
def load_model(name)