Commit 30ead05b authored by Yuxin Wu

optimize_for_inference can fail (#1064)

parent 21c49469
@@ -4,8 +4,11 @@ about: Report unexpected problems about Tensorpack or its examples.
 ---
-If you're asking about an unexpected problem you met, use this template.
-__PLEASE DO NOT DELETE THIS TEMPLATE, FILL IT__:
+If you're asking about an unexpected problem whose root cause you do not know,
+use this template. __PLEASE DO NOT DELETE THIS TEMPLATE, FILL IT__:
+If you already know the root cause of your problem,
+feel free to delete everything in this template.

 ### 1. What you did:
@@ -54,5 +57,5 @@ not our responsibility to figure out.
   using an IDE or jupyter notebook), please retry under a normal command line shell.
 + Hardware information, e.g. number of GPUs used.

-Feel free to add extra information related to your issue, but
-please try to provide the above information __accurately__ to save effort in the investigation.
+You may often want to provide extra information related to your issue, but
+at the minimum, please try to provide the above information __accurately__ to save effort in the investigation.
@@ -99,6 +99,10 @@ demonstrates the usage of such a frozen/pruned graph.
 Again, you may often want to use a different graph for inference and you can
 do so by the arguments of `PredictConfig`.
+
+Note that the exporter relies on TensorFlow's automatic graph transformation, which does not always work reliably.
+Automated graph transformation is often suboptimal and can sometimes fail.
+It's safer to write the graph yourself.

 ## Inference After Training: Do It Yourself
...
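To make the new escape hatch concrete, here is a minimal sketch of exporting with the optimization pass disabled; the `MyModel` class, checkpoint path, and tensor names are hypothetical placeholders, not part of this commit:

```python
from tensorpack import PredictConfig
from tensorpack.tfutils.export import ModelExporter
from tensorpack.tfutils.sessinit import get_model_loader

# Hypothetical ModelDesc subclass, checkpoint path, and tensor names --
# substitute the ones from your own training setup.
config = PredictConfig(
    model=MyModel(),
    session_init=get_model_loader("train_log/checkpoint"),
    input_names=["input"],
    output_names=["output"],
)

# optimize=False writes the frozen graph as-is, skipping TensorFlow's
# optimize_for_inference pass, which does not work on all types of graphs.
ModelExporter(config).export_compact("frozen_graph.pb", optimize=False)
```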
@@ -34,16 +34,20 @@ class ModelExporter(object):
         super(ModelExporter, self).__init__()
         self.config = config

-    def export_compact(self, filename, toco_compatible=False):
+    def export_compact(self, filename, optimize=True, toco_compatible=False):
         """Create a self-contained inference-only graph and write final graph (in pb format) to disk.

         Args:
             filename (str): path to the output graph
+            optimize (bool): whether to use TensorFlow's `optimize_for_inference`
+                to prune and optimize the graph. This does not work on all types of graphs.
             toco_compatible (bool): See TensorFlow's
                 `optimize_for_inference
                 <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/optimize_for_inference.py>`_
                 for details. Only available after TF 1.8.
         """
+        if toco_compatible:
+            assert optimize, "toco_compatible is only effective when optimize=True!"
         self.graph = self.config._maybe_create_graph()
         with self.graph.as_default():
             input = PlaceholderInput()
@@ -70,8 +74,9 @@ class ModelExporter(object):
                 variable_names_blacklist=None)

             # prune unused nodes from graph
-            toco_args = () if get_tf_version_tuple() < (1, 8) else (toco_compatible, )
-            pruned_graph_def = optimize_for_inference_lib.optimize_for_inference(
-                frozen_graph_def,
-                [n.name[:-2] for n in input_tensors],
-                [n.name[:-2] for n in output_tensors],
+            if optimize:
+                toco_args = () if get_tf_version_tuple() < (1, 8) else (toco_compatible, )
+                frozen_graph_def = optimize_for_inference_lib.optimize_for_inference(
+                    frozen_graph_def,
+                    [n.name[:-2] for n in input_tensors],
+                    [n.name[:-2] for n in output_tensors],
@@ -79,7 +84,7 @@ class ModelExporter(object):
-                *toco_args)
+                    *toco_args)

             with gfile.FastGFile(filename, "wb") as f:
-                f.write(pruned_graph_def.SerializeToString())
+                f.write(frozen_graph_def.SerializeToString())
                 logger.info("Output graph written to {}.".format(filename))

     def export_serving(self, filename,
...
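For completeness, a sketch of loading the resulting `.pb` file for inference under TF 1.x; the filename, tensor names, and input shape are assumptions carried over from the sketch above:

```python
import numpy as np
import tensorflow as tf

# Read the frozen GraphDef written by export_compact().
with tf.gfile.GFile("frozen_graph.pb", "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# Import it into a fresh graph; name="" keeps the original tensor names.
graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(graph_def, name="")

# These names must match the input_names/output_names used at export time
# (hypothetical here).
input_t = graph.get_tensor_by_name("input:0")
output_t = graph.get_tensor_by_name("output:0")

with tf.Session(graph=graph) as sess:
    # Dummy batch with a hypothetical shape.
    result = sess.run(output_t,
                      feed_dict={input_t: np.zeros((1, 224, 224, 3), np.float32)})
```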