Commit 1cdd2e9e authored by HJYOO's avatar HJYOO Committed by Yuxin Wu

Feat. extract graph (#1284)

parent e63d8b7e
......@@ -15,6 +15,7 @@ assert six.PY3, "This example requires Python 3!"
import tensorpack.utils.viz as tpviz
from tensorpack.predict import MultiTowerOfflinePredictor, OfflinePredictor, PredictConfig
from tensorpack.tfutils import get_model_loader, get_tf_version_tuple
from tensorpack.tfutils.export import ModelExporter
from tensorpack.utils import fs, logger
from dataset import DatasetRegistry, register_coco
......@@ -114,6 +115,8 @@ if __name__ == '__main__':
parser.add_argument('--benchmark', action='store_true', help="Benchmark the speed of the model + postprocessing")
parser.add_argument('--config', help="A list of KEY=VALUE to overwrite those defined in config.py",
nargs='+')
parser.add_argument('--compact', help='if you want to save a model to .pb')
parser.add_argument('--serving', help='if you want to save a model to serving file')
args = parser.parse_args()
if args.config:
......@@ -139,6 +142,12 @@ if __name__ == '__main__':
session_init=get_model_loader(args.load),
input_names=MODEL.get_inference_tensor_names()[0],
output_names=MODEL.get_inference_tensor_names()[1])
if args.compact:
ModelExporter(predcfg).export_compact(args.compact, optimize=False)
elif args.serving:
ModelExporter(predcfg).export_serving(args.serving, optimize=False)
if args.predict:
predictor = OfflinePredictor(predcfg)
for image_file in args.predict:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment