Commit d3f11e3f authored by Yuxin Wu's avatar Yuxin Wu

[MaskRCNN] multi-GPU validation

parent 9dba9893
...@@ -11,7 +11,13 @@ __all__ = ['config', 'finalize_configs'] ...@@ -11,7 +11,13 @@ __all__ = ['config', 'finalize_configs']
class AttrDict(): class AttrDict():
_freezed = False
""" Avoid accidental creation of new hierarchies. """
def __getattr__(self, name): def __getattr__(self, name):
if self._freezed:
raise AttributeError(name)
ret = AttrDict() ret = AttrDict()
setattr(self, name, ret) setattr(self, name, ret)
return ret return ret
...@@ -24,7 +30,7 @@ class AttrDict(): ...@@ -24,7 +30,7 @@ class AttrDict():
def to_dict(self): def to_dict(self):
"""Convert to a nested dict. """ """Convert to a nested dict. """
return {k: v.to_dict() if isinstance(v, AttrDict) else v return {k: v.to_dict() if isinstance(v, AttrDict) else v
for k, v in self.__dict__.items()} for k, v in self.__dict__.items() if not k.startswith('_')}
def update_args(self, args): def update_args(self, args):
"""Update from command line args. """ """Update from command line args. """
...@@ -43,6 +49,9 @@ class AttrDict(): ...@@ -43,6 +49,9 @@ class AttrDict():
v = eval(v) v = eval(v)
setattr(dic, key, v) setattr(dic, key, v)
def freeze(self):
self._freezed = True
# avoid silent bugs # avoid silent bugs
def __eq__(self, _): def __eq__(self, _):
raise NotImplementedError() raise NotImplementedError()
...@@ -94,6 +103,7 @@ _C.TRAIN.STEPS_PER_EPOCH = 500 ...@@ -94,6 +103,7 @@ _C.TRAIN.STEPS_PER_EPOCH = 500
# Otherwise the actual steps to decrease learning rate are computed from the schedule. # Otherwise the actual steps to decrease learning rate are computed from the schedule.
# LR_SCHEDULE = [120000, 160000, 180000] # "1x" schedule in detectron # LR_SCHEDULE = [120000, 160000, 180000] # "1x" schedule in detectron
_C.TRAIN.LR_SCHEDULE = [240000, 320000, 360000] # "2x" schedule in detectron _C.TRAIN.LR_SCHEDULE = [240000, 320000, 360000] # "2x" schedule in detectron
_C.TRAIN.NUM_EVALS = 20 # number of evaluations to run during training
# preprocessing -------------------- # preprocessing --------------------
# Alternative old (worse & faster) setting: 600, 1024 # Alternative old (worse & faster) setting: 600, 1024
...@@ -208,4 +218,5 @@ def finalize_configs(is_training): ...@@ -208,4 +218,5 @@ def finalize_configs(is_training):
# autotune is too slow for inference # autotune is too slow for inference
os.environ['TF_CUDNN_USE_AUTOTUNE'] = '0' os.environ['TF_CUDNN_USE_AUTOTUNE'] = '0'
_C.freeze()
logger.info("Config: ------------------------------------------\n" + str(_C)) logger.info("Config: ------------------------------------------\n" + str(_C))
...@@ -9,7 +9,7 @@ import itertools ...@@ -9,7 +9,7 @@ import itertools
from tensorpack.utils.argtools import memoized, log_once from tensorpack.utils.argtools import memoized, log_once
from tensorpack.dataflow import ( from tensorpack.dataflow import (
imgaug, TestDataSpeed, imgaug, TestDataSpeed,
PrefetchDataZMQ, MultiProcessMapDataZMQ, MultiThreadMapData, MultiProcessMapDataZMQ, MultiThreadMapData,
MapDataComponent, DataFromList) MapDataComponent, DataFromList)
from tensorpack.utils import logger from tensorpack.utils import logger
# import tensorpack.utils.viz as tpviz # import tensorpack.utils.viz as tpviz
...@@ -381,18 +381,25 @@ def get_train_dataflow(): ...@@ -381,18 +381,25 @@ def get_train_dataflow():
return ds return ds
def get_eval_dataflow(): def get_eval_dataflow(shard=0, num_shards=1):
"""
Args:
shard, num_shards: to get subset of evaluation data
"""
imgs = COCODetection.load_many(cfg.DATA.BASEDIR, cfg.DATA.VAL, add_gt=False) imgs = COCODetection.load_many(cfg.DATA.BASEDIR, cfg.DATA.VAL, add_gt=False)
num_imgs = len(imgs)
img_per_shard = num_imgs // num_shards
img_range = (shard * img_per_shard, (shard + 1) * img_per_shard if shard + 1 < num_shards else num_imgs)
# no filter for training # no filter for training
ds = DataFromListOfDict(imgs, ['file_name', 'id']) ds = DataFromListOfDict(imgs[img_range[0]: img_range[1]], ['file_name', 'id'])
def f(fname): def f(fname):
im = cv2.imread(fname, cv2.IMREAD_COLOR) im = cv2.imread(fname, cv2.IMREAD_COLOR)
assert im is not None, fname assert im is not None, fname
return im return im
ds = MapDataComponent(ds, f, 0) ds = MapDataComponent(ds, f, 0)
if cfg.TRAINER != 'horovod': # Evaluation itself may be multi-threaded, therefore don't add prefetch here.
ds = PrefetchDataZMQ(ds, 1)
return ds return ds
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import tqdm import tqdm
import os import os
from collections import namedtuple from collections import namedtuple
from contextlib import ExitStack
import numpy as np import numpy as np
import cv2 import cv2
...@@ -90,18 +91,24 @@ def detect_one_image(img, model_func): ...@@ -90,18 +91,24 @@ def detect_one_image(img, model_func):
return results return results
def eval_coco(df, detect_func): def eval_coco(df, detect_func, tqdm_bar=None):
""" """
Args: Args:
df: a DataFlow which produces (image, image_id) df: a DataFlow which produces (image, image_id)
detect_func: a callable, takes [image] and returns [DetectionResult] detect_func: a callable, takes [image] and returns [DetectionResult]
tqdm_bar: a tqdm object to be shared among multiple evaluation instances. If None,
will create a new one.
Returns: Returns:
list of dict, to be dumped to COCO json format list of dict, to be dumped to COCO json format
""" """
df.reset_state() df.reset_state()
all_results = [] all_results = []
with tqdm.tqdm(total=df.size(), **get_tqdm_kwargs()) as pbar: # tqdm is not quite thread-safe: https://github.com/tqdm/tqdm/issues/323
with ExitStack() as stack:
if tqdm_bar is None:
tqdm_bar = stack.enter_context(
tqdm.tqdm(total=df.size(), **get_tqdm_kwargs()))
for img, img_id in df.get_data(): for img, img_id in df.get_data():
results = detect_func(img) results = detect_func(img)
for r in results: for r in results:
...@@ -124,7 +131,7 @@ def eval_coco(df, detect_func): ...@@ -124,7 +131,7 @@ def eval_coco(df, detect_func):
rle['counts'] = rle['counts'].decode('ascii') rle['counts'] = rle['counts'].decode('ascii')
res['segmentation'] = rle res['segmentation'] = rle
all_results.append(res) all_results.append(res)
pbar.update(1) tqdm_bar.update(1)
return all_results return all_results
......
...@@ -12,6 +12,7 @@ import numpy as np ...@@ -12,6 +12,7 @@ import numpy as np
import json import json
import six import six
import tensorflow as tf import tensorflow as tf
from concurrent.futures import ThreadPoolExecutor
try: try:
import horovod.tensorflow as hvd import horovod.tensorflow as hvd
except ImportError: except ImportError:
...@@ -466,33 +467,54 @@ def predict(pred_func, input_file): ...@@ -466,33 +467,54 @@ def predict(pred_func, input_file):
class EvalCallback(Callback): class EvalCallback(Callback):
"""
A callback that runs COCO evaluation once a while.
It supports multi-GPU evaluation if TRAINER=='replicated' and single-GPU evaluation if TRAINER=='horovod'
"""
def __init__(self, in_names, out_names): def __init__(self, in_names, out_names):
self._in_names, self._out_names = in_names, out_names self._in_names, self._out_names = in_names, out_names
def _setup_graph(self): def _setup_graph(self):
self.pred = self.trainer.get_predictor(self._in_names, self._out_names) num_gpu = cfg.TRAIN.NUM_GPUS
self.df = get_eval_dataflow() # Use two predictor threads per GPU to get better throughput
self.num_predictor = 1 if cfg.TRAINER == 'horovod' else num_gpu * 2
self.predictors = [self._build_coco_predictor(k % num_gpu) for k in range(self.num_predictor)]
self.dataflows = [get_eval_dataflow(shard=k, num_shards=self.num_predictor)
for k in range(self.num_predictor)]
def _build_coco_predictor(self, idx):
graph_func = self.trainer.get_predictor(self._in_names, self._out_names, device=idx)
return lambda img: detect_one_image(img, graph_func)
def _before_train(self): def _before_train(self):
EVAL_TIMES = 5 # eval 5 times during training num_eval = cfg.TRAIN.NUM_EVALS
interval = self.trainer.max_epoch // (EVAL_TIMES + 1) interval = max(self.trainer.max_epoch // (num_eval + 1), 1)
self.epochs_to_eval = set([interval * k for k in range(1, EVAL_TIMES + 1)]) self.epochs_to_eval = set([interval * k for k in range(1, num_eval + 1)])
self.epochs_to_eval.add(self.trainer.max_epoch) self.epochs_to_eval.add(self.trainer.max_epoch)
logger.info("[EvalCallback] Will evaluate at epoch " + str(sorted(self.epochs_to_eval))) if len(self.epochs_to_eval) < 15:
logger.info("[EvalCallback] Will evaluate at epoch " + str(sorted(self.epochs_to_eval)))
else:
logger.info("[EvalCallback] Will evaluate every {} epochs".format(interval))
def _eval(self): def _eval(self):
all_results = eval_coco(self.df, lambda img: detect_one_image(img, self.pred)) with ThreadPoolExecutor(max_workers=self.num_predictor, thread_name_prefix='EvalWorker') as executor, \
tqdm.tqdm(total=sum([df.size() for df in self.dataflows])) as pbar:
futures = []
for dataflow, pred in zip(self.dataflows, self.predictors):
futures.append(executor.submit(eval_coco, dataflow, pred, pbar))
all_results = list(itertools.chain(*[fut.result() for fut in futures]))
output_file = os.path.join( output_file = os.path.join(
logger.get_logger_dir(), 'outputs{}.json'.format(self.global_step)) logger.get_logger_dir(), 'outputs{}.json'.format(self.global_step))
with open(output_file, 'w') as f: with open(output_file, 'w') as f:
json.dump(all_results, f) json.dump(all_results, f)
try: try:
scores = print_evaluation_scores(output_file) scores = print_evaluation_scores(output_file)
for k, v in scores.items():
self.trainer.monitors.put_scalar(k, v)
except Exception: except Exception:
logger.exception("Exception in COCO evaluation.") logger.exception("Exception in COCO evaluation.")
scores = {}
for k, v in scores.items():
self.trainer.monitors.put_scalar(k, v)
def _trigger_epoch(self): def _trigger_epoch(self):
if self.epoch_num in self.epochs_to_eval: if self.epoch_num in self.epochs_to_eval:
...@@ -558,7 +580,7 @@ if __name__ == '__main__': ...@@ -558,7 +580,7 @@ if __name__ == '__main__':
init_lr = cfg.TRAIN.BASE_LR * 0.33 * (8. / cfg.TRAIN.NUM_GPUS) init_lr = cfg.TRAIN.BASE_LR * 0.33 * (8. / cfg.TRAIN.NUM_GPUS)
warmup_schedule = [(0, init_lr), (cfg.TRAIN.WARMUP, cfg.TRAIN.BASE_LR)] warmup_schedule = [(0, init_lr), (cfg.TRAIN.WARMUP, cfg.TRAIN.BASE_LR)]
warmup_end_epoch = cfg.TRAIN.WARMUP * 1. / stepnum warmup_end_epoch = cfg.TRAIN.WARMUP * 1. / stepnum
lr_schedule = [(int(np.ceil(warmup_end_epoch)), warmup_schedule[-1][1])] lr_schedule = [(int(np.ceil(warmup_end_epoch)), cfg.TRAIN.BASE_LR)]
factor = 8. / cfg.TRAIN.NUM_GPUS factor = 8. / cfg.TRAIN.NUM_GPUS
for idx, steps in enumerate(cfg.TRAIN.LR_SCHEDULE[:-1]): for idx, steps in enumerate(cfg.TRAIN.LR_SCHEDULE[:-1]):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment