Commit 860f7a38 authored by Yuxin Wu's avatar Yuxin Wu

Add "toco_compatible" option for `export_compact`. (#1029)

parent c44b65fc
......@@ -13,7 +13,7 @@ from tensorflow.python.platform import gfile
from tensorflow.python.tools import optimize_for_inference_lib
from ..input_source import PlaceholderInput
from ..tfutils.common import get_tensors_by_names
from ..tfutils.common import get_tensors_by_names, get_tf_version_tuple
from ..tfutils.tower import PredictTowerContext
from ..utils import logger
......@@ -34,11 +34,15 @@ class ModelExporter(object):
super(ModelExporter, self).__init__()
self.config = config
def export_compact(self, filename):
def export_compact(self, filename, toco_compatible=False):
"""Create a self-contained inference-only graph and write final graph (in pb format) to disk.
Args:
filename (str): path to the output graph
toco_compatible (bool): See TensorFlow's
`optimize_for_inference
<https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/optimize_for_inference.py>`_
for details. Only available after TF 1.8.
"""
self.graph = self.config._maybe_create_graph()
with self.graph.as_default():
......@@ -66,12 +70,13 @@ class ModelExporter(object):
variable_names_blacklist=None)
# prune unused nodes from graph
toco_args = () if get_tf_version_tuple() < (1, 8) else (toco_compatible, )
pruned_graph_def = optimize_for_inference_lib.optimize_for_inference(
frozen_graph_def,
[n.name[:-2] for n in input_tensors],
[n.name[:-2] for n in output_tensors],
[dtype.as_datatype_enum for dtype in dtypes],
False)
*toco_args)
with gfile.FastGFile(filename, "wb") as f:
f.write(pruned_graph_def.SerializeToString())
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment