freeze_graph.py
Go to the documentation of this file.
1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 # ==============================================================================
15 """Converts checkpoint variables into Const ops in a standalone GraphDef file.
16 
17 This script is designed to take a GraphDef proto, a SaverDef proto, and a set of
18 variable values stored in a checkpoint file, and output a GraphDef with all of
19 the variable ops converted into const ops containing the values of the
20 variables.
21 
22 It's useful to do this when we need to load a single file in C++, especially in
23 environments like mobile or embedded where we may not have access to the
24 RestoreTensor ops and file loading calls that they rely on.
25 
26 An example of command-line usage is:
27 bazel build tensorflow/python/tools:freeze_graph && \
28 bazel-bin/tensorflow/python/tools/freeze_graph \
29 --input_graph=some_graph_def.pb \
30 --input_checkpoint=model.ckpt-8361242 \
31 --output_graph=/tmp/frozen_graph.pb --output_node_names=softmax
32 
33 You can also look at freeze_graph_test.py for an example of how to use it.
34 
35 """
36 from __future__ import absolute_import
37 from __future__ import division
38 from __future__ import print_function
39 
40 import argparse
41 import sys
42 
43 from google.protobuf import text_format
44 
45 from tensorflow.core.framework import graph_pb2
46 from tensorflow.core.protobuf import saver_pb2
47 from tensorflow.python import pywrap_tensorflow
48 from tensorflow.python.client import session
49 from tensorflow.python.framework import graph_util
50 from tensorflow.python.framework import importer
51 from tensorflow.python.platform import app
52 from tensorflow.python.platform import gfile
53 from tensorflow.python.training import saver as saver_lib
54 
55 FLAGS = None
56 
57 
def _split_names(names):
    """Splits a comma separated string into a list of clean, non-empty names.

    Args:
      names: Comma separated string; may be empty or None.

    Returns:
      A list with surrounding whitespace stripped from each entry and empty
      entries dropped. An empty list if `names` is falsy.
    """
    if not names:
        return []
    stripped = (name.strip() for name in names.split(","))
    return [name for name in stripped if name]


def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_blacklist=""):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
      input_graph: Path of the GraphDef file to load.
      input_saver: Optional path of a SaverDef file used to build the Saver
        that restores the checkpoint. When empty, a Saver is built from the
        checkpoint variables that are also present in the graph.
      input_binary: Whether the graph and saver files are binary protos
        (otherwise they are parsed as text protos).
      input_checkpoint: Checkpoint to restore; may be a prefix when the Saver V2
        format is used.
      output_node_names: Comma separated names of the output nodes.
      restore_op_name: Unused by the updated loading code.
      filename_tensor_name: Unused by the updated loading code.
      output_graph: Path the frozen GraphDef is written to.
      clear_devices: Whether to blank the device field of every node before
        importing the graph.
      initializer_nodes: Comma separated names of initializer nodes to run
        before freezing.
      variable_names_blacklist: Comma separated names of variables to skip
        converting to constants.

    Returns:
      -1 if an input is missing or no output node was supplied (a message is
      printed in that case), otherwise None.
    """

    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    if not gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1

    if input_saver and not gfile.Exists(input_saver):
        print("Input saver file '" + input_saver + "' does not exist!")
        return -1

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if not saver_lib.checkpoint_exists(input_checkpoint):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    # Parse once up front so that a value made only of separators/whitespace
    # is rejected the same way as an empty one.
    output_names = _split_names(output_node_names)
    if not output_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    input_graph_def = graph_pb2.GraphDef()
    mode = "rb" if input_binary else "r"
    with gfile.FastGFile(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            text_format.Merge(f.read(), input_graph_def)
    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ""

    _ = importer.import_graph_def(input_graph_def, name="")

    with session.Session() as sess:
        if input_saver:
            saver_def = saver_pb2.SaverDef()
            with gfile.FastGFile(input_saver, mode) as f:
                if input_binary:
                    saver_def.ParseFromString(f.read())
                else:
                    text_format.Merge(f.read(), saver_def)
            saver = saver_lib.Saver(saver_def=saver_def)
            saver.restore(sess, input_checkpoint)
        else:
            var_list = {}
            reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
            var_to_shape_map = reader.get_variable_to_shape_map()
            for key in var_to_shape_map:
                try:
                    tensor = sess.graph.get_tensor_by_name(key + ":0")
                except KeyError:
                    # This tensor doesn't exist in the graph (for example it's
                    # 'global_step' or a similar housekeeping element) so skip
                    # it.
                    continue
                var_list[key] = tensor
            saver = saver_lib.Saver(var_list=var_list)
            saver.restore(sess, input_checkpoint)
            # NOTE(review): the listing this was taken from lost its
            # indentation; the initializer step is kept on this branch,
            # directly after the restore -- confirm against the original file.
            #
            # The flag is a comma separated list, so it has to be split into
            # individual node names before being handed to the session;
            # passing the raw string only works for a single node.
            initializer_names = _split_names(initializer_nodes)
            if initializer_names:
                sess.run(initializer_names)

        # An empty blacklist is passed on as None, as before.
        blacklist = _split_names(variable_names_blacklist) or None
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            input_graph_def,
            output_names,
            variable_names_blacklist=blacklist)

    with gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))
143 
144 
def main(unused_args):
    """Runs freeze_graph with the values held in the module-level FLAGS.

    Args:
      unused_args: Arguments forwarded by the launcher; ignored.

    Returns:
      The value returned by freeze_graph (-1 when it rejects its inputs, None
      otherwise). Returning it, instead of discarding it, lets the launcher
      that invoked main report a failed run through the process exit status.
    """
    return freeze_graph(FLAGS.input_graph, FLAGS.input_saver,
                        FLAGS.input_binary, FLAGS.input_checkpoint,
                        FLAGS.output_node_names, FLAGS.restore_op_name,
                        FLAGS.filename_tensor_name, FLAGS.output_graph,
                        FLAGS.clear_devices, FLAGS.initializer_nodes,
                        FLAGS.variable_names_blacklist)
151 
152 
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # "bool" flags accept an optional value; only the text "true" (in any
    # letter case) counts as True.
    parser.register("type", "bool", lambda v: v.lower() == "true")

    # Keyword arguments shared by every flag of a given kind.
    text_flag = dict(type=str)
    bool_flag = dict(nargs="?", const=True, type="bool")

    # (flag, shared keyword arguments, default, help), in registration order.
    flag_specs = [
        ("--input_graph", text_flag, "",
         "TensorFlow \'GraphDef\' file to load."),
        ("--input_saver", text_flag, "",
         "TensorFlow saver file to load."),
        ("--input_checkpoint", text_flag, "",
         "TensorFlow variables file to load."),
        ("--output_graph", text_flag, "",
         "Output \'GraphDef\' file name."),
        ("--input_binary", bool_flag, False,
         "Whether the input files are in binary format."),
        ("--output_node_names", text_flag, "",
         "The name of the output nodes, comma separated."),
        ("--restore_op_name", text_flag, "save/restore_all",
         "The name of the master restore operator."),
        ("--filename_tensor_name", text_flag, "save/Const:0",
         "The name of the tensor holding the save path."),
        ("--clear_devices", bool_flag, True,
         "Whether to remove device specifications."),
        ("--initializer_nodes", text_flag, "",
         "comma separated list of initializer nodes to run before freezing."),
        ("--variable_names_blacklist", text_flag, "",
         " comma separated list of variables to skip converting to constants "),
    ]
    for flag_name, shared_kwargs, default_value, help_text in flag_specs:
        parser.add_argument(
            flag_name, default=default_value, help=help_text, **shared_kwargs)

    # Flags this parser does not know are passed through to app.run untouched.
    FLAGS, unparsed = parser.parse_known_args()
    app.run(main=main, argv=[sys.argv[0]] + unparsed)
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint, output_node_names, restore_op_name, filename_tensor_name, output_graph, clear_devices, initializer_nodes, variable_names_blacklist="")
Definition: freeze_graph.py:68