Commit 7c76e763 authored by Yuxin Wu's avatar Yuxin Wu

fix serving export (fix #1449)

parent 315eab9d
...@@ -150,7 +150,7 @@ if __name__ == '__main__': ...@@ -150,7 +150,7 @@ if __name__ == '__main__':
if args.output_pb: if args.output_pb:
ModelExporter(predcfg).export_compact(args.output_pb, optimize=False) ModelExporter(predcfg).export_compact(args.output_pb, optimize=False)
elif args.output_serving: elif args.output_serving:
ModelExporter(predcfg).export_serving(args.output_serving, optimize=False) ModelExporter(predcfg).export_serving(args.output_serving)
if args.predict: if args.predict:
predictor = OfflinePredictor(predcfg) predictor = OfflinePredictor(predcfg)
......
...@@ -97,7 +97,7 @@ class ModelExporter(object): ...@@ -97,7 +97,7 @@ class ModelExporter(object):
Args: Args:
filename (str): path for export directory filename (str): path for export directory
tags (tuple): tuple of user specified tags. Defaults to "SERVING". tags (tuple): tuple of user specified tags. Defaults to just "SERVING".
signature_name (str): name of signature for prediction signature_name (str): name of signature for prediction
Note: Note:
...@@ -115,7 +115,7 @@ class ModelExporter(object): ...@@ -115,7 +115,7 @@ class ModelExporter(object):
""" """
if tags is None: if tags is None:
tags = (tf.saved_model.SERVING if get_tf_version_tuple() >= (1, 12) tags = (tf.saved_model.SERVING if get_tf_version_tuple() >= (1, 12)
else tf.saved_model.tag_constants.SERVING) else tf.saved_model.tag_constants.SERVING, )
self.graph = self.config._maybe_create_graph() self.graph = self.config._maybe_create_graph()
with self.graph.as_default(): with self.graph.as_default():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment