Commit d3f11e3f authored by Yuxin Wu

[MaskRCNN] multi-GPU validation

parent 9dba9893
@@ -11,7 +11,13 @@ __all__ = ['config', 'finalize_configs']
class AttrDict():
_freezed = False
""" Avoid accidental creation of new hierarchies. """
def __getattr__(self, name):
if self._freezed:
raise AttributeError(name)
ret = AttrDict()
setattr(self, name, ret)
return ret
@@ -24,7 +30,7 @@ class AttrDict():
def to_dict(self):
"""Convert to a nested dict. """
return {k: v.to_dict() if isinstance(v, AttrDict) else v
for k, v in self.__dict__.items()}
for k, v in self.__dict__.items() if not k.startswith('_')}
def update_args(self, args):
"""Update from command line args. """
@@ -43,6 +49,9 @@ class AttrDict():
v = eval(v)
setattr(dic, key, v)
def freeze(self):
self._freezed = True
# avoid silent bugs
def __eq__(self, _):
raise NotImplementedError()
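
For orientation, a minimal sketch of the AttrDict semantics these hunks establish; the example is illustrative and not part of the commit:

    _C = AttrDict()
    _C.TRAIN.LR = 0.01       # __getattr__ auto-creates the intermediate AttrDict node
    print(_C.to_dict())      # {'TRAIN': {'LR': 0.01}} -- keys starting with '_' are skipped
    _C.freeze()
    _C.TEST.SCALE = 800      # raises AttributeError: reading the missing 'TEST' node after freeze()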
@@ -94,6 +103,7 @@ _C.TRAIN.STEPS_PER_EPOCH = 500
# Otherwise the actual steps to decrease learning rate are computed from the schedule.
# LR_SCHEDULE = [120000, 160000, 180000] # "1x" schedule in detectron
_C.TRAIN.LR_SCHEDULE = [240000, 320000, 360000] # "2x" schedule in detectron
_C.TRAIN.NUM_EVALS = 20 # number of evaluations to run during training
# preprocessing --------------------
# Alternative old (worse & faster) setting: 600, 1024
@@ -208,4 +218,5 @@ def finalize_configs(is_training):
# autotune is too slow for inference
os.environ['TF_CUDNN_USE_AUTOTUNE'] = '0'
_C.freeze()
logger.info("Config: ------------------------------------------\n" + str(_C))
@@ -9,7 +9,7 @@ import itertools
from tensorpack.utils.argtools import memoized, log_once
from tensorpack.dataflow import (
imgaug, TestDataSpeed,
PrefetchDataZMQ, MultiProcessMapDataZMQ, MultiThreadMapData,
MultiProcessMapDataZMQ, MultiThreadMapData,
MapDataComponent, DataFromList)
from tensorpack.utils import logger
# import tensorpack.utils.viz as tpviz
@@ -381,18 +381,25 @@ def get_train_dataflow():
return ds
def get_eval_dataflow():
def get_eval_dataflow(shard=0, num_shards=1):
"""
Args:
shard, num_shards: to get subset of evaluation data
"""
imgs = COCODetection.load_many(cfg.DATA.BASEDIR, cfg.DATA.VAL, add_gt=False)
num_imgs = len(imgs)
img_per_shard = num_imgs // num_shards
img_range = (shard * img_per_shard, (shard + 1) * img_per_shard if shard + 1 < num_shards else num_imgs)
# no filter for training
ds = DataFromListOfDict(imgs, ['file_name', 'id'])
ds = DataFromListOfDict(imgs[img_range[0]: img_range[1]], ['file_name', 'id'])
def f(fname):
im = cv2.imread(fname, cv2.IMREAD_COLOR)
assert im is not None, fname
return im
ds = MapDataComponent(ds, f, 0)
if cfg.TRAINER != 'horovod':
ds = PrefetchDataZMQ(ds, 1)
# Evaluation itself may be multi-threaded, therefore don't add prefetch here.
return ds
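
A quick check of the sharding arithmetic above, with made-up sizes; note that the last shard absorbs the remainder, so no image is dropped:

    num_imgs, num_shards = 10, 3
    img_per_shard = num_imgs // num_shards                    # 3
    for shard in range(num_shards):
        start = shard * img_per_shard
        end = (shard + 1) * img_per_shard if shard + 1 < num_shards else num_imgs
        print(shard, (start, end))                            # (0, 3), (3, 6), (6, 10)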
......
@@ -4,6 +4,7 @@
import tqdm
import os
from collections import namedtuple
from contextlib import ExitStack
import numpy as np
import cv2
@@ -90,18 +91,24 @@ def detect_one_image(img, model_func):
return results
def eval_coco(df, detect_func):
def eval_coco(df, detect_func, tqdm_bar=None):
"""
Args:
df: a DataFlow which produces (image, image_id)
detect_func: a callable, takes [image] and returns [DetectionResult]
tqdm_bar: a tqdm object to be shared among multiple evaluation instances. If None,
will create a new one.
Returns:
list of dict, to be dumped to COCO json format
"""
df.reset_state()
all_results = []
with tqdm.tqdm(total=df.size(), **get_tqdm_kwargs()) as pbar:
# tqdm is not quite thread-safe: https://github.com/tqdm/tqdm/issues/323
with ExitStack() as stack:
if tqdm_bar is None:
tqdm_bar = stack.enter_context(
tqdm.tqdm(total=df.size(), **get_tqdm_kwargs()))
for img, img_id in df.get_data():
results = detect_func(img)
for r in results:
@@ -124,7 +131,7 @@ def eval_coco(df, detect_func):
rle['counts'] = rle['counts'].decode('ascii')
res['segmentation'] = rle
all_results.append(res)
pbar.update(1)
tqdm_bar.update(1)
return all_results
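
The ExitStack above lets eval_coco either own its progress bar or borrow a shared one. A self-contained sketch of that ownership pattern; the work function and item counts are stand-ins, not from the commit:

    import tqdm
    from contextlib import ExitStack

    def work(items, tqdm_bar=None):
        with ExitStack() as stack:
            if tqdm_bar is None:
                # standalone call: create the bar and let the stack close it on exit
                tqdm_bar = stack.enter_context(tqdm.tqdm(total=len(items)))
            for _ in items:
                tqdm_bar.update(1)
            # a borrowed bar is left open for the caller to close

    work(range(8))                       # owns (and closes) its own bar
    with tqdm.tqdm(total=16) as shared:
        work(range(8), shared)           # two calls sharing one bar
        work(range(8), shared)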
......
@@ -12,6 +12,7 @@ import numpy as np
import json
import six
import tensorflow as tf
from concurrent.futures import ThreadPoolExecutor
try:
import horovod.tensorflow as hvd
except ImportError:
@@ -466,33 +467,54 @@ def predict(pred_func, input_file):
class EvalCallback(Callback):
"""
A callback that runs COCO evaluation once in a while.
It supports multi-GPU evaluation if TRAINER=='replicated' and single-GPU evaluation if TRAINER=='horovod'.
"""
def __init__(self, in_names, out_names):
self._in_names, self._out_names = in_names, out_names
def _setup_graph(self):
self.pred = self.trainer.get_predictor(self._in_names, self._out_names)
self.df = get_eval_dataflow()
num_gpu = cfg.TRAIN.NUM_GPUS
# Use two predictor threads per GPU to get better throughput
self.num_predictor = 1 if cfg.TRAINER == 'horovod' else num_gpu * 2
self.predictors = [self._build_coco_predictor(k % num_gpu) for k in range(self.num_predictor)]
self.dataflows = [get_eval_dataflow(shard=k, num_shards=self.num_predictor)
for k in range(self.num_predictor)]
def _build_coco_predictor(self, idx):
graph_func = self.trainer.get_predictor(self._in_names, self._out_names, device=idx)
return lambda img: detect_one_image(img, graph_func)
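
The k % num_gpu indexing round-robins the predictor threads over the available GPUs; with the hypothetical counts below, two threads land on each device:

    num_gpu = 2
    num_predictor = num_gpu * 2                          # two predictor threads per GPU
    print([k % num_gpu for k in range(num_predictor)])   # [0, 1, 0, 1]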
def _before_train(self):
EVAL_TIMES = 5 # eval 5 times during training
interval = self.trainer.max_epoch // (EVAL_TIMES + 1)
self.epochs_to_eval = set([interval * k for k in range(1, EVAL_TIMES + 1)])
num_eval = cfg.TRAIN.NUM_EVALS
interval = max(self.trainer.max_epoch // (num_eval + 1), 1)
self.epochs_to_eval = set([interval * k for k in range(1, num_eval + 1)])
self.epochs_to_eval.add(self.trainer.max_epoch)
logger.info("[EvalCallback] Will evaluate at epoch " + str(sorted(self.epochs_to_eval)))
if len(self.epochs_to_eval) < 15:
logger.info("[EvalCallback] Will evaluate at epoch " + str(sorted(self.epochs_to_eval)))
else:
logger.info("[EvalCallback] Will evaluate every {} epochs".format(interval))
def _eval(self):
all_results = eval_coco(self.df, lambda img: detect_one_image(img, self.pred))
with ThreadPoolExecutor(max_workers=self.num_predictor, thread_name_prefix='EvalWorker') as executor, \
tqdm.tqdm(total=sum([df.size() for df in self.dataflows])) as pbar:
futures = []
for dataflow, pred in zip(self.dataflows, self.predictors):
futures.append(executor.submit(eval_coco, dataflow, pred, pbar))
all_results = list(itertools.chain(*[fut.result() for fut in futures]))
output_file = os.path.join(
logger.get_logger_dir(), 'outputs{}.json'.format(self.global_step))
with open(output_file, 'w') as f:
json.dump(all_results, f)
try:
scores = print_evaluation_scores(output_file)
for k, v in scores.items():
self.trainer.monitors.put_scalar(k, v)
except Exception:
logger.exception("Exception in COCO evaluation.")
scores = {}
for k, v in scores.items():
self.trainer.monitors.put_scalar(k, v)
def _trigger_epoch(self):
if self.epoch_num in self.epochs_to_eval:
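
The new _eval body is a fan-out/fan-in pattern: one thread per data shard, results concatenated in shard order. A standalone sketch of the same shape, where work and the shard contents are stand-ins for eval_coco and the dataflows:

    import itertools
    from concurrent.futures import ThreadPoolExecutor

    def work(shard):
        # stand-in for eval_coco(dataflow, predictor, shared_pbar)
        return [x * x for x in shard]

    shards = [range(0, 3), range(3, 6), range(6, 10)]
    with ThreadPoolExecutor(max_workers=len(shards)) as executor:
        futures = [executor.submit(work, s) for s in shards]
        # fut.result() blocks, re-raises any worker exception, and keeps shard order
        results = list(itertools.chain(*[fut.result() for fut in futures]))
    print(results)   # squares of 0..9, in shard order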
@@ -558,7 +580,7 @@ if __name__ == '__main__':
init_lr = cfg.TRAIN.BASE_LR * 0.33 * (8. / cfg.TRAIN.NUM_GPUS)
warmup_schedule = [(0, init_lr), (cfg.TRAIN.WARMUP, cfg.TRAIN.BASE_LR)]
warmup_end_epoch = cfg.TRAIN.WARMUP * 1. / stepnum
lr_schedule = [(int(np.ceil(warmup_end_epoch)), warmup_schedule[-1][1])]
lr_schedule = [(int(np.ceil(warmup_end_epoch)), cfg.TRAIN.BASE_LR)]
factor = 8. / cfg.TRAIN.NUM_GPUS
for idx, steps in enumerate(cfg.TRAIN.LR_SCHEDULE[:-1]):
......
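
With hypothetical config values, the warmup arithmetic in the last hunk resolves as below; note that warmup_schedule[-1][1] and cfg.TRAIN.BASE_LR hold the same value, so the change is for readability, not behavior:

    import numpy as np
    # Assumed values: BASE_LR = 0.02, WARMUP = 1000 steps, 8 GPUs, stepnum = 500 steps/epoch.
    BASE_LR, WARMUP, NUM_GPUS, stepnum = 0.02, 1000, 8, 500
    init_lr = BASE_LR * 0.33 * (8. / NUM_GPUS)           # 0.0066
    warmup_schedule = [(0, init_lr), (WARMUP, BASE_LR)]
    warmup_end_epoch = WARMUP * 1. / stepnum             # 2.0
    lr_schedule = [(int(np.ceil(warmup_end_epoch)), BASE_LR)]
    print(lr_schedule)                                   # [(2, 0.02)]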