68 variable_names_blacklist=
""):
69 """Converts all variables in a graph and checkpoint into constants.""" 71 del restore_op_name, filename_tensor_name
73 if not gfile.Exists(input_graph):
74 print(
"Input graph file '" + input_graph +
"' does not exist!")
77 if input_saver
and not gfile.Exists(input_saver):
78 print(
"Input saver file '" + input_saver +
"' does not exist!")
82 if not saver_lib.checkpoint_exists(input_checkpoint):
83 print(
"Input checkpoint '" + input_checkpoint +
"' doesn't exist!")
86 if not output_node_names:
87 print(
"You need to supply the name of a node to --output_node_names.")
90 input_graph_def = graph_pb2.GraphDef()
91 mode =
"rb" if input_binary
else "r" 92 with gfile.FastGFile(input_graph, mode) as f:
94 input_graph_def.ParseFromString(f.read())
96 text_format.Merge(f.read(), input_graph_def)
100 for node
in input_graph_def.node:
103 _ = importer.import_graph_def(input_graph_def, name=
"")
105 with session.Session()
as sess:
107 with gfile.FastGFile(input_saver, mode)
as f:
108 saver_def = saver_pb2.SaverDef()
110 saver_def.ParseFromString(f.read())
112 text_format.Merge(f.read(), saver_def)
113 saver = saver_lib.Saver(saver_def=saver_def)
114 saver.restore(sess, input_checkpoint)
117 reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
118 var_to_shape_map = reader.get_variable_to_shape_map()
119 for key
in var_to_shape_map:
121 tensor = sess.graph.get_tensor_by_name(key +
":0")
126 var_list[key] = tensor
127 saver = saver_lib.Saver(var_list=var_list)
128 saver.restore(sess, input_checkpoint)
129 if initializer_nodes:
130 sess.run(initializer_nodes)
132 variable_names_blacklist = (variable_names_blacklist.split(
",")
if 133 variable_names_blacklist
else None)
134 output_graph_def = graph_util.convert_variables_to_constants(
137 output_node_names.split(
","),
138 variable_names_blacklist=variable_names_blacklist)
140 with gfile.GFile(output_graph,
"wb")
as f:
141 f.write(output_graph_def.SerializeToString())
142 print(
"%d ops in the final graph." % len(output_graph_def.node))