keras_to_tensorflow.py
Go to the documentation of this file.
1 
2 # coding: utf-8
3 
4 # In[ ]:
5 
"""
Copyright (c) 2017, by the Authors: Amir H. Abdi
This software is freely available under the MIT Public License.
Please see the License file in the root for details.

This script converts a Keras model file, saved using
model.save('kerasmodel_weight_file'), into a frozen TensorFlow .pb file
which holds both the network architecture and its associated weights.
""";
16 
17 
18 # In[ ]:
19 
'''
Input arguments (each is passed as `-name value`, e.g. `-num_outputs 2`):

num_outputs: this value has nothing to do with the number of classes, batch_size, etc.,
and it is mostly equal to 1. If the network is a **multi-stream network**
(forked network with multiple outputs), set the value to the number of outputs
to export. [default: 1]

quantize: if set to True, use the quantize feature of TensorFlow
(https://www.tensorflow.org/performance/quantization) [default: False]

theano_backend: Theano and TensorFlow implement convolution in different ways.
When using Keras with the Theano backend, the image data format is set to 'channels_first'.
This feature is not fully tested, and doesn't work with quantization [default: False]

input_fld: directory holding the keras weights file [default: .]

output_fld: destination directory to save the tensorflow files [default: the value of input_fld]

input_model_file: name of the input weight file [default: 'model.h5']

output_model_file: name of the output weight file [default: input_model_file + '.pb']

graph_def: if set to True, will write the graph definition as an ascii file [default: False]

output_graphdef_file: if graph_def is set to True, the file name of the
graph definition [default: model.ascii]

output_node_prefix: the prefix to use for output node names. [default: output_]

'''
50 
51 
# Parse input arguments

# In[ ]:

import argparse


def _str2bool(value):
    """Interpret a command-line string such as 'True', 'false', '1' or 'no' as a bool.

    argparse's ``type=bool`` simply calls ``bool(text)``, which is True for
    every non-empty string -- so ``-quantize False`` would *enable*
    quantization.  This converter reads the text instead.

    Args:
        value: the raw argument text (a bool is passed through unchanged).

    Returns:
        The parsed boolean.  An empty string is treated as False, matching
        what ``bool('')`` used to produce.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean
            spelling; argparse turns this into a usage error.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'boolean value expected, got {!r}'.format(value))


parser = argparse.ArgumentParser(description='set input arguments')
parser.add_argument('-input_fld', action="store",
                    dest='input_fld', type=str, default='.')
# An empty string means "same as input_fld" (resolved in the next cell).
parser.add_argument('-output_fld', action="store",
                    dest='output_fld', type=str, default='')
parser.add_argument('-input_model_file', action="store",
                    dest='input_model_file', type=str, default='model.h5')
# An empty string means "input_model_file + '.pb'" (resolved in the next cell).
parser.add_argument('-output_model_file', action="store",
                    dest='output_model_file', type=str, default='')
parser.add_argument('-output_graphdef_file', action="store",
                    dest='output_graphdef_file', type=str, default='model.ascii')
parser.add_argument('-num_outputs', action="store",
                    dest='num_outputs', type=int, default=1)
# Boolean options take an explicit value, e.g. `-graph_def True` / `-graph_def False`.
parser.add_argument('-graph_def', action="store",
                    dest='graph_def', type=_str2bool, default=False)
parser.add_argument('-output_node_prefix', action="store",
                    dest='output_node_prefix', type=str, default='output_')
parser.add_argument('-quantize', action="store",
                    dest='quantize', type=_str2bool, default=False)
parser.add_argument('-theano_backend', action="store",
                    dest='theano_backend', type=_str2bool, default=False)
# Unused by the script; presumably accepts the `-f <connection file>` argument
# Jupyter passes when this is run as a notebook, so parsing does not fail there.
parser.add_argument('-f')
args = parser.parse_args()
print('input args: ', args)

if args.theano_backend and args.quantize:
    raise ValueError("Quantize feature does not work with theano backend.")
85 
86 
# initialize

# In[ ]:

import os

from keras.models import load_model
import tensorflow as tf
from keras import backend as K

# Output files go next to the input model unless -output_fld was given.
output_fld = args.input_fld if args.output_fld == '' else args.output_fld
if args.output_model_file == '':
    args.output_model_file = str(args.input_model_file) + '.pb'
# Locate the Keras model inside -input_fld, as that option is documented to do.
# os.path.join returns an absolute input_model_file unchanged, and with the
# default input_fld '.' the path still refers to the current directory.
weight_file_path = os.path.join(args.input_fld, str(args.input_model_file))
101 
102 
# Load keras model and rename output

# In[ ]:

# Learning phase 0 is Keras' test/inference phase; fix it before the model is
# built so the exported graph is the inference graph.
K.set_learning_phase(0)
# Keras with the Theano backend orders image data channels-first; the
# TensorFlow convention is channels-last.
K.set_image_data_format('channels_first' if args.theano_backend else 'channels_last')

try:
    net_model = load_model(weight_file_path)
except ValueError:
    # load_model raised ValueError here; the script assumes this means the file
    # holds weights only (saved via save_weights) and explains how to fix it.
    print('Input file specified ({}) only holds the weights, and not the model definition.\n'
          'Save the model using model.save(filename.h5) which will contain the network '
          'architecture as well as its weights.\n'
          'If the model is saved using model.save_weights(filename.h5), the model '
          'architecture is expected to be saved separately in a json format and loaded '
          'prior to loading the weights.\n'
          'Check the keras documentation for more details '
          '(https://keras.io/getting-started/faq/)'.format(weight_file_path))
    raise
num_output = args.num_outputs
# Fail early with a clear message instead of an IndexError on net_model.outputs.
if not 1 <= num_output <= len(net_model.outputs):
    raise ValueError(
        '-num_outputs must be between 1 and the number of model outputs ({}), got {}'
        .format(len(net_model.outputs), num_output))

# Descriptive suffixes given, in order, to the model's outputs.
# NOTE(review): these look specific to the model this script was adapted for;
# confirm they match the order of net_model.outputs.
OUTPUT_NODE_SUFFIXES = ('is_antineutrino', 'flavour', 'interaction',
                        'protons', 'pions', 'pizeros', 'neutrons')


def _output_node_name(prefix, index):
    """Return the graph node name to give model output number *index*.

    Outputs covered by OUTPUT_NODE_SUFFIXES get a descriptive suffix; any
    further outputs fall back to their numeric index so every name is unique.
    Reusing a name would make TensorFlow rename the later node while the
    frozen-graph output list still pointed at the first one.
    """
    if index < len(OUTPUT_NODE_SUFFIXES):
        return prefix + OUTPUT_NODE_SUFFIXES[index]
    return prefix + str(index)


# Wrap each exported output in a named identity op so it can be looked up by name.
pred = [tf.identity(net_model.outputs[i],
                    name=_output_node_name(args.output_node_prefix, i))
        for i in range(num_output)]
# Record the names TensorFlow actually assigned: if a requested name was already
# taken in the graph, tf.identity uniquifies it, and freezing needs the real one.
pred_node_names = [tensor.op.name for tensor in pred]
print('output nodes names are: ', pred_node_names)
153 
154 
# [optional] write graph definition in ascii

# In[ ]:

import os

sess = K.get_session()

if args.graph_def:
    f = args.output_graphdef_file
    tf.train.write_graph(sess.graph.as_graph_def(), output_fld, f, as_text=True)
    # write_graph places the file inside output_fld; report that location.
    # Paths are joined explicitly because str does not support the `/` operator.
    print('saved the graph definition in ascii format at: ', os.path.join(output_fld, f))
165 
166 
# convert variables to constants and save

# In[ ]:

from tensorflow.python.framework import graph_util
from tensorflow.python.framework import graph_io

# Snapshot the session graph once; optionally rewrite it with TF graph transforms.
graph_def_to_freeze = sess.graph.as_graph_def()
if args.quantize:
    from tensorflow.tools.graph_transforms import TransformGraph
    transforms = ["quantize_weights", "quantize_nodes"]
    # NOTE(review): the transforms run before variables are frozen into
    # constants; confirm quantize_weights actually affects the model weights.
    graph_def_to_freeze = TransformGraph(graph_def_to_freeze, [], pred_node_names, transforms)

# Replace variables with constants holding their current session values, so the
# graph no longer needs a checkpoint, then serialise it in binary form.
constant_graph = graph_util.convert_variables_to_constants(
    sess, graph_def_to_freeze, pred_node_names)
graph_io.write_graph(constant_graph, output_fld, args.output_model_file, as_text=False)
print('saved the freezed graph (ready for inference) at: ', str(args.output_model_file))
182 
static bool format(QChar::Decomposition tag, QString &str, int index, int len)
Definition: qstring.cpp:11496
def load_model(name)