Commit 88e900a9 authored by Yuxin Wu's avatar Yuxin Wu

[MaskRCNN] more than one cfg.DATA.VAL

parent d6393ea3
...@@ -108,7 +108,7 @@ Performance in [Detectron](https://github.com/facebookresearch/Detectron/) can b ...@@ -108,7 +108,7 @@ Performance in [Detectron](https://github.com/facebookresearch/Detectron/) can b
[R101FPN9xGNCasAugScratch]: http://models.tensorpack.com/FasterRCNN/COCO-R101FPN-MaskRCNN-ScratchGN.npz [R101FPN9xGNCasAugScratch]: http://models.tensorpack.com/FasterRCNN/COCO-R101FPN-MaskRCNN-ScratchGN.npz
<a id="ft1">1</a>: Numbers taken from [Detectron Model Zoo](https://github.com/facebookresearch/Detectron/blob/master/MODEL_ZOO.md). <a id="ft1">1</a>: Numbers taken from [Detectron Model Zoo](https://github.com/facebookresearch/Detectron/blob/master/MODEL_ZOO.md).
We comapre models that have identical training & inference cost between the two implementation. However their numbers can be different due to many small implementation details. We compare models that have identical training & inference cost between the two implementations. However their numbers can be different due to many small implementation details.
For example, our FPN models are sometimes slightly worse in box AP, which is probably due to batch size. For example, our FPN models are sometimes slightly worse in box AP, which is probably due to batch size.
<a id="ft2">2</a>: Numbers taken from Table 5 in [Group Normalization](https://arxiv.org/abs/1803.08494) <a id="ft2">2</a>: Numbers taken from Table 5 in [Group Normalization](https://arxiv.org/abs/1803.08494)
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import numpy as np import numpy as np
import os import os
import six
import pprint import pprint
from tensorpack.utils import logger from tensorpack.utils import logger
...@@ -80,9 +81,10 @@ _C.MODE_FPN = False ...@@ -80,9 +81,10 @@ _C.MODE_FPN = False
# dataset ----------------------- # dataset -----------------------
_C.DATA.BASEDIR = '/path/to/your/COCO/DIR' _C.DATA.BASEDIR = '/path/to/your/COCO/DIR'
# All TRAIN dataset will be concatenated for training.
_C.DATA.TRAIN = ['train2014', 'valminusminival2014'] # i.e. trainval35k, AKA train2017 _C.DATA.TRAIN = ['train2014', 'valminusminival2014'] # i.e. trainval35k, AKA train2017
# For now, only support evaluation on single dataset # Each VAL dataset will be evaluated separately (instead of concatenated)
_C.DATA.VAL = 'minival2014' # AKA val2017 _C.DATA.VAL = ('minival2014', ) # AKA val2017
_C.DATA.NUM_CATEGORY = 80 # 80 categories in COCO _C.DATA.NUM_CATEGORY = 80 # 80 categories in COCO
_C.DATA.CLASS_NAMES = [] # NUM_CLASS (NUM_CATEGORY+1) strings, the first is "BG". _C.DATA.CLASS_NAMES = [] # NUM_CLASS (NUM_CATEGORY+1) strings, the first is "BG".
# For COCO, this list will be populated later by the COCO data loader. # For COCO, this list will be populated later by the COCO data loader.
...@@ -210,6 +212,8 @@ def finalize_configs(is_training): ...@@ -210,6 +212,8 @@ def finalize_configs(is_training):
_C.freeze(False) # populate new keys now _C.freeze(False) # populate new keys now
_C.DATA.NUM_CLASS = _C.DATA.NUM_CATEGORY + 1 # +1 background _C.DATA.NUM_CLASS = _C.DATA.NUM_CATEGORY + 1 # +1 background
_C.DATA.BASEDIR = os.path.expanduser(_C.DATA.BASEDIR) _C.DATA.BASEDIR = os.path.expanduser(_C.DATA.BASEDIR)
if isinstance(_C.DATA.VAL, six.string_types): # support single string (the typical case) as well
_C.DATA.VAL = (_C.DATA.VAL, )
assert _C.BACKBONE.NORM in ['FreezeBN', 'SyncBN', 'GN', 'None'], _C.BACKBONE.NORM assert _C.BACKBONE.NORM in ['FreezeBN', 'SyncBN', 'GN', 'None'], _C.BACKBONE.NORM
if _C.BACKBONE.NORM != 'FreezeBN': if _C.BACKBONE.NORM != 'FreezeBN':
...@@ -246,6 +250,10 @@ def finalize_configs(is_training): ...@@ -246,6 +250,10 @@ def finalize_configs(is_training):
if _C.TRAINER == 'horovod': if _C.TRAINER == 'horovod':
import horovod.tensorflow as hvd import horovod.tensorflow as hvd
ngpu = hvd.size() ngpu = hvd.size()
if ngpu == hvd.local_size():
logger.warn("It's not recommended to use horovod for single-machine training. "
"Replicated trainer is more stable and has the same efficiency.")
else: else:
assert 'OMPI_COMM_WORLD_SIZE' not in os.environ assert 'OMPI_COMM_WORLD_SIZE' not in os.environ
ngpu = get_num_gpu() ngpu = get_num_gpu()
......
...@@ -381,12 +381,20 @@ def get_train_dataflow(): ...@@ -381,12 +381,20 @@ def get_train_dataflow():
return ds return ds
def get_eval_dataflow(shard=0, num_shards=1): def get_eval_dataflow(name, shard=0, num_shards=1):
""" """
Args: Args:
name (str): name of the dataset to evaluate
shard, num_shards: to get subset of evaluation data shard, num_shards: to get subset of evaluation data
""" """
roidbs = COCODetection.load_many(cfg.DATA.BASEDIR, cfg.DATA.VAL, add_gt=False) roidbs = COCODetection.load_many(cfg.DATA.BASEDIR, name, add_gt=False)
"""
To inference on your own data, change this to your loader.
Produce "roidbs" as a list of dict, in the dict the following keys are needed for training:
file_name: str, full path to the image
id: an id of this image
"""
num_imgs = len(roidbs) num_imgs = len(roidbs)
img_per_shard = num_imgs // num_shards img_per_shard = num_imgs // num_shards
img_range = (shard * img_per_shard, (shard + 1) * img_per_shard if shard + 1 < num_shards else num_imgs) img_range = (shard * img_per_shard, (shard + 1) * img_per_shard if shard + 1 < num_shards else num_imgs)
......
...@@ -161,12 +161,19 @@ def multithread_eval_coco(dataflows, detect_funcs): ...@@ -161,12 +161,19 @@ def multithread_eval_coco(dataflows, detect_funcs):
# https://github.com/pdollar/coco/blob/master/PythonAPI/pycocoEvalDemo.ipynb # https://github.com/pdollar/coco/blob/master/PythonAPI/pycocoEvalDemo.ipynb
def print_coco_metrics(json_file): def print_coco_metrics(dataset, json_file):
"""
Args:
dataset (str): name of the dataset
json_file (str): path to the results json file in coco format
If your data is not in COCO format, write your own evaluation function.
"""
ret = {} ret = {}
assert cfg.DATA.BASEDIR and os.path.isdir(cfg.DATA.BASEDIR) assert cfg.DATA.BASEDIR and os.path.isdir(cfg.DATA.BASEDIR)
annofile = os.path.join( annofile = os.path.join(
cfg.DATA.BASEDIR, 'annotations', cfg.DATA.BASEDIR, 'annotations',
'instances_{}.json'.format(cfg.DATA.VAL)) 'instances_{}.json'.format(dataset))
coco = COCO(annofile) coco = COCO(annofile)
cocoDt = coco.loadRes(json_file) cocoDt = coco.loadRes(json_file)
cocoEval = COCOeval(coco, cocoDt, 'bbox') cocoEval = COCOeval(coco, cocoDt, 'bbox')
......
...@@ -379,19 +379,24 @@ def offline_evaluate(pred_config, output_file): ...@@ -379,19 +379,24 @@ def offline_evaluate(pred_config, output_file):
num_gpu = cfg.TRAIN.NUM_GPUS num_gpu = cfg.TRAIN.NUM_GPUS
graph_funcs = MultiTowerOfflinePredictor( graph_funcs = MultiTowerOfflinePredictor(
pred_config, list(range(num_gpu))).get_predictors() pred_config, list(range(num_gpu))).get_predictors()
predictors = [] predictors = []
dataflows = []
for k in range(num_gpu): for k in range(num_gpu):
predictors.append(lambda img, predictors.append(lambda img,
pred=graph_funcs[k]: detect_one_image(img, pred)) pred=graph_funcs[k]: detect_one_image(img, pred))
dataflows.append(get_eval_dataflow(shard=k, num_shards=num_gpu)) for dataset in cfg.DATA.VAL:
if num_gpu > 1: logger.info("Evaluating {} ...".format(dataset))
all_results = multithread_eval_coco(dataflows, predictors) dataflows = [
else: get_eval_dataflow(dataset, shard=k, num_shards=num_gpu)
all_results = eval_coco(dataflows[0], predictors[0]) for k in range(num_gpu) ]
with open(output_file, 'w') as f: if num_gpu > 1:
json.dump(all_results, f) all_results = multithread_eval_coco(dataflows, predictors)
print_coco_metrics(output_file) else:
all_results = eval_coco(dataflows[0], predictors[0])
output = output_file + '-' + dataset
with open(output, 'w') as f:
json.dump(all_results, f)
print_coco_metrics(dataset, output)
def predict(pred_func, input_file): def predict(pred_func, input_file):
...@@ -412,7 +417,8 @@ class EvalCallback(Callback): ...@@ -412,7 +417,8 @@ class EvalCallback(Callback):
_chief_only = False _chief_only = False
def __init__(self, in_names, out_names): def __init__(self, eval_dataset, in_names, out_names):
self._eval_dataset = eval_dataset
self._in_names, self._out_names = in_names, out_names self._in_names, self._out_names = in_names, out_names
def _setup_graph(self): def _setup_graph(self):
...@@ -424,7 +430,8 @@ class EvalCallback(Callback): ...@@ -424,7 +430,8 @@ class EvalCallback(Callback):
# Use two predictor threads per GPU to get better throughput # Use two predictor threads per GPU to get better throughput
self.num_predictor = num_gpu if buggy_tf else num_gpu * 2 self.num_predictor = num_gpu if buggy_tf else num_gpu * 2
self.predictors = [self._build_coco_predictor(k % num_gpu) for k in range(self.num_predictor)] self.predictors = [self._build_coco_predictor(k % num_gpu) for k in range(self.num_predictor)]
self.dataflows = [get_eval_dataflow(shard=k, num_shards=self.num_predictor) self.dataflows = [get_eval_dataflow(self._eval_dataset,
shard=k, num_shards=self.num_predictor)
for k in range(self.num_predictor)] for k in range(self.num_predictor)]
else: else:
# Only eval on the first machine. # Only eval on the first machine.
...@@ -432,7 +439,8 @@ class EvalCallback(Callback): ...@@ -432,7 +439,8 @@ class EvalCallback(Callback):
self._horovod_run_eval = hvd.rank() == hvd.local_rank() self._horovod_run_eval = hvd.rank() == hvd.local_rank()
if self._horovod_run_eval: if self._horovod_run_eval:
self.predictor = self._build_coco_predictor(0) self.predictor = self._build_coco_predictor(0)
self.dataflow = get_eval_dataflow(shard=hvd.local_rank(), num_shards=hvd.local_size()) self.dataflow = get_eval_dataflow(self._eval_dataset,
shard=hvd.local_rank(), num_shards=hvd.local_size())
self.barrier = hvd.allreduce(tf.random_normal(shape=[1])) self.barrier = hvd.allreduce(tf.random_normal(shape=[1]))
...@@ -475,11 +483,11 @@ class EvalCallback(Callback): ...@@ -475,11 +483,11 @@ class EvalCallback(Callback):
os.unlink(fname) os.unlink(fname)
output_file = os.path.join( output_file = os.path.join(
logdir, 'outputs{}.json'.format(self.global_step)) logdir, '{}-outputs{}.json'.format(self._eval_dataset, self.global_step))
with open(output_file, 'w') as f: with open(output_file, 'w') as f:
json.dump(all_results, f) json.dump(all_results, f)
try: try:
scores = print_coco_metrics(output_file) scores = print_coco_metrics(self._eval_dataset, output_file)
for k, v in scores.items(): for k, v in scores.items():
self.trainer.monitors.put_scalar(k, v) self.trainer.monitors.put_scalar(k, v)
except Exception: except Exception:
...@@ -565,6 +573,7 @@ if __name__ == '__main__': ...@@ -565,6 +573,7 @@ if __name__ == '__main__':
total_passes = cfg.TRAIN.LR_SCHEDULE[-1] * 8 / train_dataflow.size() total_passes = cfg.TRAIN.LR_SCHEDULE[-1] * 8 / train_dataflow.size()
logger.info("Total passes of the training set is: {:.5g}".format(total_passes)) logger.info("Total passes of the training set is: {:.5g}".format(total_passes))
callbacks = [ callbacks = [
PeriodicCallback( PeriodicCallback(
ModelSaver(max_to_keep=10, keep_checkpoint_every_n_hours=1), ModelSaver(max_to_keep=10, keep_checkpoint_every_n_hours=1),
...@@ -573,10 +582,12 @@ if __name__ == '__main__': ...@@ -573,10 +582,12 @@ if __name__ == '__main__':
ScheduledHyperParamSetter( ScheduledHyperParamSetter(
'learning_rate', warmup_schedule, interp='linear', step_based=True), 'learning_rate', warmup_schedule, interp='linear', step_based=True),
ScheduledHyperParamSetter('learning_rate', lr_schedule), ScheduledHyperParamSetter('learning_rate', lr_schedule),
EvalCallback(*MODEL.get_inference_tensor_names()),
PeakMemoryTracker(), PeakMemoryTracker(),
EstimatedTimeLeft(median=True), EstimatedTimeLeft(median=True),
SessionRunTimeout(60000).set_chief_only(True), # 1 minute timeout SessionRunTimeout(60000).set_chief_only(True), # 1 minute timeout
] + [
EvalCallback(dataset, *MODEL.get_inference_tensor_names())
for dataset in cfg.DATA.VAL
] ]
if not is_horovod: if not is_horovod:
callbacks.append(GPUUtilizationTracker()) callbacks.append(GPUUtilizationTracker())
......
...@@ -128,7 +128,7 @@ if __name__ == '__main__': ...@@ -128,7 +128,7 @@ if __name__ == '__main__':
model.data_format = args.data_format model.data_format = args.data_format
if args.eval: if args.eval:
batch = 128 # something that can run on one gpu batch = 128 # something that can run on one gpu
ds = get_data('val', batch) ds = get_imagenet_dataflow(args.data, 'val', batch)
eval_on_ILSVRC12(model, get_model_loader(args.load), ds) eval_on_ILSVRC12(model, get_model_loader(args.load), ds)
else: else:
if args.fake: if args.fake:
......
...@@ -82,7 +82,7 @@ class TowerTrainer(Trainer): ...@@ -82,7 +82,7 @@ class TowerTrainer(Trainer):
def get_predictor(self, input_names, output_names, device=0): def get_predictor(self, input_names, output_names, device=0):
""" """
This method will build the tower under ``TowerContext(is_training=False)``, This method will build the trainer's tower function under ``TowerContext(is_training=False)``,
and returns a callable predictor with input placeholders & output tensors in this tower. and returns a callable predictor with input placeholders & output tensors in this tower.
Args: Args:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment