Commit 860f7a38 authored by Yuxin Wu

Add "toco_compatible" option for `export_compact`. (#1029)

parent c44b65fc
...
@@ -13,7 +13,7 @@
 from tensorflow.python.platform import gfile
 from tensorflow.python.tools import optimize_for_inference_lib
 
 from ..input_source import PlaceholderInput
-from ..tfutils.common import get_tensors_by_names
+from ..tfutils.common import get_tensors_by_names, get_tf_version_tuple
 from ..tfutils.tower import PredictTowerContext
 from ..utils import logger
...
@@ -34,11 +34,15 @@ class ModelExporter(object):
         super(ModelExporter, self).__init__()
         self.config = config
 
-    def export_compact(self, filename):
+    def export_compact(self, filename, toco_compatible=False):
         """Create a self-contained inference-only graph and write final graph (in pb format) to disk.
 
         Args:
             filename (str): path to the output graph
+            toco_compatible (bool): See TensorFlow's
+                `optimize_for_inference
+                <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/optimize_for_inference.py>`_
+                for details. Only available after TF 1.8.
         """
         self.graph = self.config._maybe_create_graph()
         with self.graph.as_default():
...
@@ -66,12 +70,13 @@ class ModelExporter(object):
                     variable_names_blacklist=None)
 
                 # prune unused nodes from graph
+                toco_args = () if get_tf_version_tuple() < (1, 8) else (toco_compatible, )
                 pruned_graph_def = optimize_for_inference_lib.optimize_for_inference(
                     frozen_graph_def,
                     [n.name[:-2] for n in input_tensors],
                     [n.name[:-2] for n in output_tensors],
                     [dtype.as_datatype_enum for dtype in dtypes],
-                    False)
+                    *toco_args)
 
         with gfile.FastGFile(filename, "wb") as f:
             f.write(pruned_graph_def.SerializeToString())
...
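The `*toco_args` splat is the backward-compatibility trick in the last hunk: `optimize_for_inference` only gained the `toco_compatible` parameter in TF 1.8, so on older releases an empty tuple is splatted and the call keeps its old four-argument form. A minimal sketch of the pattern in isolation; the function names here are illustrative, and `tf_version_tuple` is a stand-in for tensorpack's `get_tf_version_tuple` helper in `tensorpack.tfutils.common`:

import tensorflow as tf
from tensorflow.python.tools import optimize_for_inference_lib

def tf_version_tuple():
    # Stand-in for tensorpack's helper: "1.13.1" -> (1, 13).
    return tuple(map(int, tf.__version__.split('.')[:2]))

def prune_for_inference(graph_def, input_names, output_names, dtype_enums,
                        toco_compatible=False):
    # Splatting an empty tuple on TF < 1.8 passes no extra argument,
    # so the older 4-argument signature still works unchanged.
    toco_args = () if tf_version_tuple() < (1, 8) else (toco_compatible, )
    return optimize_for_inference_lib.optimize_for_inference(
        graph_def, input_names, output_names, dtype_enums, *toco_args)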
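For context, a hedged sketch of how the exporter is called after this change, assuming tensorpack's usual `PredictConfig` API; `MyModel`, the tensor names, and the checkpoint path are placeholders rather than anything from this commit:

from tensorpack import PredictConfig, SaverRestore
from tensorpack.tfutils.export import ModelExporter

config = PredictConfig(
    model=MyModel(),                       # hypothetical ModelDesc subclass
    session_init=SaverRestore('/path/to/checkpoint'),  # placeholder checkpoint
    input_names=['input'],                 # placeholder tensor names
    output_names=['prediction'])

# On TF >= 1.8 this requests a TOCO-compatible pruned graph; on older TF
# the flag is accepted but not forwarded, due to the version guard above.
ModelExporter(config).export_compact('/tmp/graph.pb', toco_compatible=True)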