Commit 1cdd2e9e authored by HJYOO's avatar HJYOO Committed by Yuxin Wu

Feat. extract graph (#1284)

parent e63d8b7e
...@@ -15,6 +15,7 @@ assert six.PY3, "This example requires Python 3!"
import tensorpack.utils.viz as tpviz
from tensorpack.predict import MultiTowerOfflinePredictor, OfflinePredictor, PredictConfig
from tensorpack.tfutils import get_model_loader, get_tf_version_tuple
from tensorpack.tfutils.export import ModelExporter
from tensorpack.utils import fs, logger
from dataset import DatasetRegistry, register_coco
...@@ -114,6 +115,8 @@ if __name__ == '__main__':
parser.add_argument('--benchmark', action='store_true', help="Benchmark the speed of the model + postprocessing")
parser.add_argument('--config', help="A list of KEY=VALUE to overwrite those defined in config.py",
                    nargs='+')
parser.add_argument('--compact', help='if you want to save a model to .pb')
parser.add_argument('--serving', help='if you want to save a model to serving file')
args = parser.parse_args()
if args.config:
...@@ -139,6 +142,12 @@ if __name__ == '__main__':
    session_init=get_model_loader(args.load),
    input_names=MODEL.get_inference_tensor_names()[0],
    output_names=MODEL.get_inference_tensor_names()[1])
if args.compact:
ModelExporter(predcfg).export_compact(args.compact, optimize=False)
elif args.serving:
ModelExporter(predcfg).export_serving(args.serving, optimize=False)
if args.predict:
    predictor = OfflinePredictor(predcfg)
    for image_file in args.predict:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment