"""Converts checkpoint variables into Const ops in a standalone GraphDef file.

This script is designed to take a GraphDef proto, a SaverDef proto, and a set of
variable values stored in a checkpoint file, and output a GraphDef with all of
the variable ops converted into const ops containing the values of the
variables.

It's useful to do this when we need to load a single file in C++, especially in
environments like mobile or embedded where we may not have access to the
RestoreTensor ops and file loading calls that they rely on.

An example of command-line usage is:
bazel build tensorflow/python/tools:freeze_graph && \
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=some_graph_def.pb \
--input_checkpoint=model.ckpt-8361242 \
--output_graph=/tmp/frozen_graph.pb --output_node_names=softmax

You can also look at freeze_graph_test.py for an example of how to use it.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

from google.protobuf import text_format

from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.platform import app
from tensorflow.python.platform import gfile
from tensorflow.python.training import saver as saver_lib
# Parsed command-line flags; assigned in the ``__main__`` block, read by main().
FLAGS = None


def _split_names(names):
    """Splits a comma-separated list of graph names into a clean list.

    Surrounding whitespace is stripped from every entry and empty entries are
    dropped, so " a, b ,," yields ["a", "b"].

    Args:
      names: A comma-separated string. "" or None yields [].

    Returns:
      A list of non-empty, whitespace-stripped names.
    """
    if not names:
        return []
    return [name.strip() for name in names.split(",") if name.strip()]


def _read_proto(path, proto, input_binary):
    """Parses the file at `path` into the protobuf message `proto`.

    Args:
      path: File to read.
      proto: Message instance that is filled in place.
      input_binary: If true, the file is parsed as a binary proto
        (ParseFromString); otherwise as a text proto (text_format.Merge).

    Returns:
      `proto`, for convenience.
    """
    mode = "rb" if input_binary else "r"
    with gfile.FastGFile(path, mode) as f:
        if input_binary:
            proto.ParseFromString(f.read())
        else:
            text_format.Merge(f.read(), proto)
    return proto


def _restore_variables(sess, input_saver, input_binary, input_checkpoint):
    """Restores the values stored in `input_checkpoint` into `sess`.

    If `input_saver` is given, its SaverDef drives the restore. Otherwise a
    Saver is built from every checkpoint entry whose "<name>:0" tensor exists
    in the session's graph. Entries without such a tensor are skipped.
    """
    if input_saver:
        saver_def = _read_proto(input_saver, saver_pb2.SaverDef(), input_binary)
        saver = saver_lib.Saver(saver_def=saver_def)
    else:
        reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
        var_list = {}
        for key in reader.get_variable_to_shape_map():
            try:
                tensor = sess.graph.get_tensor_by_name(key + ":0")
            except KeyError:
                # The checkpoint holds a value with no counterpart tensor in
                # this graph, so there is nothing to restore it into.
                continue
            var_list[key] = tensor
        saver = saver_lib.Saver(var_list=var_list)
    saver.restore(sess, input_checkpoint)


def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_blacklist=""):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
      input_graph: Path of the GraphDef file to load.
      input_saver: Optional path of a SaverDef file. When empty, the variables
        to restore are discovered from the checkpoint itself.
      input_binary: If true, `input_graph` and `input_saver` are parsed as
        binary protos; otherwise as text protos.
      input_checkpoint: Checkpoint holding the variable values.
      output_node_names: Comma-separated names of the graph's output nodes.
      restore_op_name: Unused; kept for backward compatibility.
      filename_tensor_name: Unused; kept for backward compatibility.
      output_graph: Path the frozen, binary GraphDef is written to.
      clear_devices: If true, device placements are removed from every node
        before the graph is imported.
      initializer_nodes: Comma-separated names of nodes to run after restoring
        and before freezing. May be empty.
      variable_names_blacklist: Comma-separated names of variables that are
        not converted to constants. May be empty.

    Returns:
      -1 (after printing a message) if an input is missing or no output node
      was named; otherwise None once `output_graph` has been written.
    """
    # The restore op and filename tensor are not needed: variables are
    # restored through a Saver built in _restore_variables() instead.
    del restore_op_name, filename_tensor_name

    if not gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1

    if input_saver and not gfile.Exists(input_saver):
        print("Input saver file '" + input_saver + "' does not exist!")
        return -1

    # checkpoint_exists() rather than gfile.Exists(): presumably because
    # `input_checkpoint` may be a checkpoint prefix rather than a single file.
    if not saver_lib.checkpoint_exists(input_checkpoint):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    # Validate the parsed list, not the raw string, so that input made only of
    # commas/whitespace is rejected here instead of failing deep in freezing.
    output_nodes = _split_names(output_node_names)
    if not output_nodes:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    input_graph_def = _read_proto(input_graph, graph_pb2.GraphDef(),
                                  input_binary)

    # Remove explicit device specifications (the --clear_devices option).
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ""

    importer.import_graph_def(input_graph_def, name="")

    with session.Session() as sess:
        _restore_variables(sess, input_saver, input_binary, input_checkpoint)

        initializers = _split_names(initializer_nodes)
        if initializers:
            # Session.run() takes a list of op names. The raw comma-separated
            # string would be looked up as one op named e.g. "a,b".
            sess.run(initializers)

        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            input_graph_def,
            output_nodes,
            # An empty blacklist is passed as None, as before.
            variable_names_blacklist=_split_names(variable_names_blacklist)
            or None)

    with gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))


def main(unused_args):
    """Runs freeze_graph() with the parsed command-line FLAGS.

    freeze_graph()'s status (-1 on invalid input) is returned instead of
    discarded. app.run() hands main's return value to sys.exit(), so a failed
    run now exits non-zero.
    """
    return freeze_graph(FLAGS.input_graph, FLAGS.input_saver,
                        FLAGS.input_binary, FLAGS.input_checkpoint,
                        FLAGS.output_node_names, FLAGS.restore_op_name,
                        FLAGS.filename_tensor_name, FLAGS.output_graph,
                        FLAGS.clear_devices, FLAGS.initializer_nodes,
                        FLAGS.variable_names_blacklist)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # "bool" flags accept "true" (any case) as True; anything else is False.
    parser.register("type", "bool", lambda v: v.lower() == "true")
    # NOTE(review): type/default lines of most flags below (including the
    # input_binary/clear_devices defaults) were reconstructed; confirm.
    parser.add_argument(
        "--input_graph",
        type=str,
        default="",
        help="TensorFlow \'GraphDef\' file to load.")
    parser.add_argument(
        "--input_saver",
        type=str,
        default="",
        help="TensorFlow saver file to load.")
    parser.add_argument(
        "--input_checkpoint",
        type=str,
        default="",
        help="TensorFlow variables file to load.")
    parser.add_argument(
        "--output_graph",
        type=str,
        default="",
        help="Output \'GraphDef\' file name.")
    parser.add_argument(
        "--input_binary",
        nargs="?",
        const=True,
        type="bool",
        default=False,
        help="Whether the input files are in binary format.")
    parser.add_argument(
        "--output_node_names",
        type=str,
        default="",
        help="The name of the output nodes, comma separated.")
    parser.add_argument(
        "--restore_op_name",
        type=str,
        default="save/restore_all",
        help="The name of the master restore operator.")
    parser.add_argument(
        "--filename_tensor_name",
        type=str,
        default="save/Const:0",
        help="The name of the tensor holding the save path.")
    parser.add_argument(
        "--clear_devices",
        nargs="?",
        const=True,
        type="bool",
        default=True,
        help="Whether to remove device specifications.")
    parser.add_argument(
        "--initializer_nodes",
        type=str,
        default="",
        help="comma separated list of initializer nodes to run before freezing.")
    parser.add_argument(
        "--variable_names_blacklist",
        type=str,
        default="",
        help="""\
      comma separated list of variables to skip converting to constants\
      """)
    FLAGS, unparsed = parser.parse_known_args()
    app.run(main=main, argv=[sys.argv[0]] + unparsed)