Commit d451368a authored by Yuxin Wu's avatar Yuxin Wu

[MaskRCNN] inference speed test

parent 9c6e39c5
...@@ -89,7 +89,8 @@ class GeneralizedRCNN(ModelDesc): ...@@ -89,7 +89,8 @@ class GeneralizedRCNN(ModelDesc):
ns = G.get_name_scope() ns = G.get_name_scope()
for name in self.get_inference_tensor_names()[1]: for name in self.get_inference_tensor_names()[1]:
try: try:
G.get_tensor_by_name('/'.join([ns, name + ':0'])) name = '/'.join([ns, name]) if ns else name
G.get_tensor_by_name(name + ':0')
except KeyError: except KeyError:
raise KeyError("Your model does not define the tensor '{}' in inference context.".format(name)) raise KeyError("Your model does not define the tensor '{}' in inference context.".format(name))
......
...@@ -111,6 +111,7 @@ if __name__ == '__main__': ...@@ -111,6 +111,7 @@ if __name__ == '__main__':
"This argument is the path to the output json evaluation file") "This argument is the path to the output json evaluation file")
parser.add_argument('--predict', help="Run prediction on a given image. " parser.add_argument('--predict', help="Run prediction on a given image. "
"This argument is the path to the input image file", nargs='+') "This argument is the path to the input image file", nargs='+')
parser.add_argument('--benchmark', action='store_true', help="Benchmark the speed of the model + postprocessing")
parser.add_argument('--config', help="A list of KEY=VALUE to overwrite those defined in config.py", parser.add_argument('--config', help="A list of KEY=VALUE to overwrite those defined in config.py",
nargs='+') nargs='+')
...@@ -145,3 +146,11 @@ if __name__ == '__main__': ...@@ -145,3 +146,11 @@ if __name__ == '__main__':
elif args.evaluate: elif args.evaluate:
assert args.evaluate.endswith('.json'), args.evaluate assert args.evaluate.endswith('.json'), args.evaluate
do_evaluate(predcfg, args.evaluate) do_evaluate(predcfg, args.evaluate)
elif args.benchmark:
df = get_eval_dataflow(cfg.DATA.VAL[0])
df.reset_state()
predictor = OfflinePredictor(predcfg)
for img in tqdm.tqdm(df, total=len(df)):
# This include post-processing time, which is done on CPU and not optimized
# To exclude it, modify `predict_image`.
predict_image(img[0], predictor)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment