Commit b79a9d3b authored by Yuxin Wu

[MaskRCNN] parallel offline evaluation; NORM=None

parent f91a30ca
......@@ -17,8 +17,8 @@ This is a minimal implementation that simply contains these files:
Data:
1. It's easy to train on your own data. Just replace `COCODetection.load_many` in `data.py` by your own loader.
Also remember to change `config.NUM_CLASS` and `config.CLASS_NAMES`.
The current evaluation code is also COCO-specific, and you need to change it to use your data and metrics.
Also remember to change `DATA.NUM_CATEGORY` and `DATA.CLASS_NAMES` in the config.
The current evaluation code is also COCO-specific, and you may need to change it to use your data and metrics.
2. You can easily add more augmentations such as rotation, but be careful how a box should be
augmented. The code now will always use the minimal axis-aligned bounding box of the 4 corners,
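
As a side note on the box-augmentation caveat above, here is a minimal sketch of "the minimal axis-aligned bounding box of the 4 corners"; the helper name and signature are illustrative, not from the repo:

```python
import numpy as np

def augment_box(box, transform_points):
    """box: [x0, y0, x1, y1]; transform_points: maps a (4, 2) array of points to a (4, 2) array."""
    x0, y0, x1, y1 = box
    corners = np.array([[x0, y0], [x1, y0], [x1, y1], [x0, y1]], dtype=np.float32)
    corners = transform_points(corners)          # e.g. rotate the 4 corners
    # keep the minimal axis-aligned bounding box of the transformed corners
    return np.concatenate([corners.min(axis=0), corners.max(axis=0)])
```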
......@@ -46,7 +46,7 @@ Model:
Speed:
1. If cudnn warmup is on, the training will start very slowly, until about
1. If CuDNN warmup is on, the training will start very slowly, until about
10k steps (or more if scale augmentation is used) to reach a maximum speed.
As a result, the ETA is also inaccurate at the beginning.
Warmup is by default on when no scale augmentation is used.
......
......@@ -8,7 +8,7 @@ following object detection / instance segmentation papers:
+ [Cascade R-CNN: Delving into High Quality Object Detection](https://arxiv.org/abs/1712.00726)
with the support of:
+ Multi-GPU / distributed training
+ Multi-GPU / distributed training, multi-GPU evaluation
+ Cross-GPU BatchNorm (aka Sync-BN, from [MegDet: A Large Mini-Batch Object Detector](https://arxiv.org/abs/1711.07240))
+ [Group Normalization](https://arxiv.org/abs/1803.08494)
+ Training from scratch (from [Rethinking ImageNet Pre-training](https://arxiv.org/abs/1811.08883))
......
......@@ -108,6 +108,8 @@ def image_preprocess(image, bgr=True):
def get_norm(zero_init=False):
if cfg.BACKBONE.NORM == 'None':
return lambda x: x
if cfg.BACKBONE.NORM == 'GN':
Norm = GroupNorm
layer_name = 'gn'
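
For context, the full `get_norm` factory presumably looks like the sketch below once the `'None'` branch is added; the non-`'None'` branches and the exact `Norm` call are assumptions based on the surrounding module (`cfg`, `GroupNorm`, `BatchNorm`, `tf` are `basemodel.py`'s own imports):

```python
def get_norm(zero_init=False):
    if cfg.BACKBONE.NORM == 'None':
        return lambda x: x                      # no normalization layer at all
    if cfg.BACKBONE.NORM == 'GN':
        Norm, layer_name = GroupNorm, 'gn'
    else:                                       # 'FreezeBN' / 'SyncBN' (assumed)
        Norm, layer_name = BatchNorm, 'bn'
    return lambda x: Norm(layer_name, x,
                          gamma_initializer=tf.zeros_initializer() if zero_init else None)
```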
......@@ -144,7 +146,11 @@ def resnet_bottleneck(l, ch_out, stride):
l = Conv2D('conv2', l, ch_out, 3, strides=2, padding='VALID')
else:
l = Conv2D('conv2', l, ch_out, 3, strides=stride)
if cfg.BACKBONE.NORM != 'None':
l = Conv2D('conv3', l, ch_out * 4, 1, activation=get_norm(zero_init=True))
else:
l = Conv2D('conv3', l, ch_out * 4, 1, activation=tf.identity,
kernel_initializer=tf.constant_initializer())
ret = l + resnet_shortcut(shortcut, ch_out * 4, stride, activation=get_norm(zero_init=False))
return tf.nn.relu(ret, name='output')
......
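
Why the `tf.constant_initializer()` in the `'None'` branch of the bottleneck above: with no norm layer there is no gamma to zero-initialize, so the last conv kernel itself starts at zero, which makes the residual branch begin as an identity mapping. A small sketch of the reasoning (comments only describe the intent, they are not code from the diff):

```python
import tensorflow as tf

# tf.constant_initializer() defaults to value=0, so 'conv3' starts with an all-zero kernel.
zero_kernel = tf.constant_initializer()
# At initialization the residual branch therefore outputs zeros, and
#   ret = l + resnet_shortcut(...)  reduces to  resnet_shortcut(...),
# mirroring the usual zero-init-gamma trick used when a norm layer is present.
```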
......@@ -82,15 +82,16 @@ _C.DATA.BASEDIR = '/path/to/your/COCO/DIR'
_C.DATA.TRAIN = ['train2014', 'valminusminival2014'] # i.e. trainval35k, AKA train2017
# For now, only support evaluation on single dataset
_C.DATA.VAL = 'minival2014' # AKA val2017
_C.DATA.NUM_CATEGORY = 80 # 80 categories.
_C.DATA.CLASS_NAMES = [] # NUM_CLASS (NUM_CATEGORY+1) strings, to be populated later by data loader. The first is BG.
_C.DATA.NUM_CATEGORY = 80 # 80 categories in COCO
_C.DATA.CLASS_NAMES = [] # NUM_CLASS (NUM_CATEGORY+1) strings, the first is "BG".
# For COCO, this list will be populated later by the COCO data loader.
# basemodel ----------------------
_C.BACKBONE.WEIGHTS = '' # /path/to/weights.npz
_C.BACKBONE.RESNET_NUM_BLOCKS = [3, 4, 6, 3] # for resnet50
# RESNET_NUM_BLOCKS = [3, 4, 23, 3] # for resnet101
_C.BACKBONE.FREEZE_AFFINE = False # do not train affine parameters inside norm layers
_C.BACKBONE.NORM = 'FreezeBN' # options: FreezeBN, SyncBN, GN
_C.BACKBONE.NORM = 'FreezeBN' # options: FreezeBN, SyncBN, GN, None
_C.BACKBONE.FREEZE_AT = 2 # options: 0, 1, 2
# Use a base model with TF-preferred padding mode,
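
A hedged sketch of overriding these options in Python for a custom dataset before training; the attribute names come from the hunk above, while the concrete values and path are illustrative:

```python
from config import config as cfg, finalize_configs

cfg.DATA.BASEDIR = '/path/to/my/dataset'                 # illustrative
cfg.DATA.NUM_CATEGORY = 3                                # number of foreground categories
cfg.DATA.CLASS_NAMES = ['BG', 'cat', 'dog', 'bird']      # NUM_CATEGORY + 1 names, "BG" first
cfg.BACKBONE.NORM = 'None'                               # newly-allowed value in this commit
finalize_configs(is_training=True)                       # fills in DATA.NUM_CLASS, checks options
```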
......@@ -208,7 +209,7 @@ def finalize_configs(is_training):
_C.DATA.NUM_CLASS = _C.DATA.NUM_CATEGORY + 1 # +1 background
_C.DATA.BASEDIR = os.path.expanduser(_C.DATA.BASEDIR)
assert _C.BACKBONE.NORM in ['FreezeBN', 'SyncBN', 'GN'], _C.BACKBONE.NORM
assert _C.BACKBONE.NORM in ['FreezeBN', 'SyncBN', 'GN', 'None'], _C.BACKBONE.NORM
if _C.BACKBONE.NORM != 'FreezeBN':
assert not _C.BACKBONE.FREEZE_AFFINE
assert _C.BACKBONE.FREEZE_AT in [0, 1, 2]
......@@ -246,8 +247,13 @@ def finalize_configs(is_training):
else:
assert 'OMPI_COMM_WORLD_SIZE' not in os.environ
ngpu = get_num_gpu()
assert ngpu % 8 == 0 or 8 % ngpu == 0, "Can only train with 1,2,4 or >=8 GPUs, but found {} GPUs".format(ngpu)
else:
# autotune is too slow for inference
os.environ['TF_CUDNN_USE_AUTOTUNE'] = '0'
ngpu = get_num_gpu()
assert ngpu > 0, "Has to run with GPU!"
assert ngpu % 8 == 0 or 8 % ngpu == 0, "Can only run with 1,2,4 or >=8 GPUs, but found {} GPUs".format(ngpu)
if _C.TRAIN.NUM_GPUS is None:
_C.TRAIN.NUM_GPUS = ngpu
else:
......@@ -255,9 +261,6 @@ def finalize_configs(is_training):
assert _C.TRAIN.NUM_GPUS == ngpu
else:
assert _C.TRAIN.NUM_GPUS <= ngpu
else:
# autotune is too slow for inference
os.environ['TF_CUDNN_USE_AUTOTUNE'] = '0'
_C.freeze()
logger.info("Config: ------------------------------------------\n" + str(_C))
......@@ -5,8 +5,10 @@ import tqdm
import os
from collections import namedtuple
from contextlib import ExitStack
import itertools
import numpy as np
import cv2
from concurrent.futures import ThreadPoolExecutor
from tensorpack.utils.utils import get_tqdm_kwargs
......@@ -29,12 +31,14 @@ mask: None, or a binary image of the original image shape
"""
def fill_full_mask(box, mask, shape):
def paste_mask(box, mask, shape):
"""
Args:
box: 4 float
mask: MxM floats
shape: h,w
Returns:
A uint8 binary image of hxw.
"""
# int() is floor
# box fpcoor=0.0 -> intcoor=0.0
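
A minimal sketch of what `paste_mask` does according to its docstring: resize the MxM soft mask to the box size, binarize it, and paste it into an all-zero hxw canvas. The coordinate rounding below follows the `fpcoor`/`intcoor` comments but should be read as an assumption, not the exact diff:

```python
import cv2
import numpy as np

def paste_mask_sketch(box, mask, shape):
    """box: 4 floats (x0, y0, x1, y1); mask: MxM floats; shape: (h, w)."""
    x0, y0 = int(box[0] + 0.5), int(box[1] + 0.5)                   # int() is floor here
    x1, y1 = max(x0, int(box[2] - 0.5)), max(y0, int(box[3] - 0.5))
    w, h = x1 + 1 - x0, y1 + 1 - y0
    patch = (cv2.resize(mask, (w, h)) > 0.5).astype('uint8')        # dsize is (width, height)
    canvas = np.zeros(shape, dtype='uint8')
    canvas[y0:y1 + 1, x0:x1 + 1] = patch
    return canvas
```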
......@@ -80,7 +84,7 @@ def detect_one_image(img, model_func):
if masks:
# has mask
full_masks = [fill_full_mask(box, mask, orig_shape)
full_masks = [paste_mask(box, mask, orig_shape)
for box, mask in zip(boxes, masks[0])]
masks = full_masks
else:
......@@ -135,8 +139,30 @@ def eval_coco(df, detect_func, tqdm_bar=None):
return all_results
def multithread_eval_coco(dataflows, detect_funcs):
"""
Run `eval_coco` in multiple threads and aggregate the results.
Args:
dataflows: a list of DataFlow to be used in :func:`eval_coco`
detect_funcs: a list of callable to be used in :func:`eval_coco`
Returns:
list of dict, to be dumped to COCO json format
"""
num_worker = len(dataflows)
assert len(dataflows) == len(detect_funcs)
with ThreadPoolExecutor(max_workers=num_worker, thread_name_prefix='EvalWorker') as executor, \
tqdm.tqdm(total=sum([df.size() for df in dataflows])) as pbar:
futures = []
for dataflow, pred in zip(dataflows, detect_funcs):
futures.append(executor.submit(eval_coco, dataflow, pred, pbar))
all_results = list(itertools.chain(*[fut.result() for fut in futures]))
return all_results
# https://github.com/pdollar/coco/blob/master/PythonAPI/pycocoEvalDemo.ipynb
def print_evaluation_scores(json_file):
def print_coco_metrics(json_file):
ret = {}
assert cfg.DATA.BASEDIR and os.path.isdir(cfg.DATA.BASEDIR)
annofile = os.path.join(
......
......@@ -12,7 +12,6 @@ import numpy as np
import json
import six
import tensorflow as tf
from concurrent.futures import ThreadPoolExecutor
try:
import horovod.tensorflow as hvd
except ImportError:
......@@ -52,7 +51,8 @@ from viz import (
draw_annotation, draw_proposal_recall,
draw_predictions, draw_final_outputs)
from eval import (
eval_coco, detect_one_image, print_evaluation_scores, DetectionResult)
eval_coco, multithread_eval_coco,
detect_one_image, print_coco_metrics, DetectionResult)
from config import finalize_configs, config as cfg
......@@ -388,13 +388,23 @@ def visualize(model, model_path, nr_visualize=100, output_dir='output'):
pbar.update()
def offline_evaluate(pred_func, output_file):
df = get_eval_dataflow()
all_results = eval_coco(
df, lambda img: detect_one_image(img, pred_func))
def offline_evaluate(pred_config, output_file):
num_gpu = cfg.TRAIN.NUM_GPUS
graph_funcs = MultiTowerOfflinePredictor(
pred_config, list(range(num_gpu))).get_predictors()
predictors = []
dataflows = []
for k in range(num_gpu):
predictors.append(lambda img,
pred=graph_funcs[k]: detect_one_image(img, pred))
dataflows.append(get_eval_dataflow(shard=k, num_shards=num_gpu))
if num_gpu > 1:
all_results = multithread_eval_coco(dataflows, predictors)
else:
all_results = eval_coco(dataflows[0], predictors[0])
with open(output_file, 'w') as f:
json.dump(all_results, f)
print_evaluation_scores(output_file)
print_coco_metrics(output_file)
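
The `pred=graph_funcs[k]` default argument in the loop above is what binds each lambda to its own predictor; without it, all lambdas would share the final value of `k`, because Python closures are late-binding. A quick illustration:

```python
funcs = [lambda: k for k in range(3)]
print([f() for f in funcs])          # [2, 2, 2]  -- every closure sees the last k

funcs = [lambda k=k: k for k in range(3)]
print([f() for f in funcs])          # [0, 1, 2]  -- default argument captures k per iteration
```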
def predict(pred_func, input_file):
......@@ -456,12 +466,7 @@ class EvalCallback(Callback):
def _eval(self):
logdir = args.logdir
if cfg.TRAINER == 'replicated':
with ThreadPoolExecutor(max_workers=self.num_predictor, thread_name_prefix='EvalWorker') as executor, \
tqdm.tqdm(total=sum([df.size() for df in self.dataflows])) as pbar:
futures = []
for dataflow, pred in zip(self.dataflows, self.predictors):
futures.append(executor.submit(eval_coco, dataflow, pred, pbar))
all_results = list(itertools.chain(*[fut.result() for fut in futures]))
all_results = multithread_eval_coco(self.dataflows, self.predictors)
else:
filenames = [os.path.join(
logdir, 'outputs{}-part{}.json'.format(self.global_step, rank)
......@@ -487,7 +492,7 @@ class EvalCallback(Callback):
with open(output_file, 'w') as f:
json.dump(all_results, f)
try:
scores = print_evaluation_scores(output_file)
scores = print_coco_metrics(output_file)
for k, v in scores.items():
self.trainer.monitors.put_scalar(k, v)
except Exception:
......@@ -532,17 +537,17 @@ if __name__ == '__main__':
if args.visualize:
visualize(MODEL, args.load)
else:
pred = OfflinePredictor(PredictConfig(
predcfg = PredictConfig(
model=MODEL,
session_init=get_model_loader(args.load),
input_names=MODEL.get_inference_tensor_names()[0],
output_names=MODEL.get_inference_tensor_names()[1]))
if args.evaluate:
assert args.evaluate.endswith('.json'), args.evaluate
offline_evaluate(pred, args.evaluate)
elif args.predict:
output_names=MODEL.get_inference_tensor_names()[1])
if args.predict:
COCODetection(cfg.DATA.BASEDIR, 'val2014') # Only to load the class names into caches
predict(pred, args.predict)
predict(OfflinePredictor(predcfg), args.predict)
elif args.evaluate:
assert args.evaluate.endswith('.json'), args.evaluate
offline_evaluate(predcfg, args.evaluate)
else:
is_horovod = cfg.TRAINER == 'horovod'
if is_horovod:
......
......@@ -122,7 +122,7 @@ class Contrast(ImageAugmentor):
else:
mean = np.mean(img)
img = (img - mean) * r + mean
img = img * r + mean * (1 - r)
if self.clip or old_dtype == np.uint8:
img = np.clip(img, 0, 255)
return img.astype(old_dtype)
......
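
The rewritten Contrast expression is algebraically identical to the old one: `(img - mean) * r + mean == img * r + mean * (1 - r)`. A quick numerical check (illustrative only):

```python
import numpy as np

img = np.random.rand(8, 8, 3).astype('float32') * 255
r, mean = 0.7, np.mean(img)
assert np.allclose((img - mean) * r + mean, img * r + mean * (1 - r))
```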
......@@ -37,9 +37,11 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
for idx, t in enumerate(towers):
tower_name = 'tower' + str(t)
device = '/gpu:{}'.format(t)
with tf.variable_scope(tf.get_variable_scope(), reuse=idx > 0), \
tf.device('/gpu:{}'.format(t)), \
tf.device(device), \
PredictTowerContext(tower_name):
logger.info("Building graph for predict tower '{}' on device {} ...".format(tower_name, device))
config.tower_func(*input.get_input_tensors())
handles.append(config.tower_func.towers[-1])
......
......@@ -80,7 +80,7 @@ class MismatchLogger(object):
self._names = []
def add(self, name):
self._names.append(name)
self._names.append(get_op_tensor_name(name)[0])
def log(self):
if len(self._names):
......
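
For context, tensorpack's `get_op_tensor_name` returns an `(op_name, tensor_name)` pair, so taking `[0]` normalizes names whether or not they carry the `:0` suffix before they are logged; a small sketch under that assumption:

```python
from tensorpack.tfutils.common import get_op_tensor_name

print(get_op_tensor_name('conv0/W'))     # ('conv0/W', 'conv0/W:0')
print(get_op_tensor_name('conv0/W:0'))   # ('conv0/W', 'conv0/W:0')
```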
......@@ -63,10 +63,12 @@ class SessionUpdate(object):
varshape = tuple(var.get_shape().as_list())
if varshape != val.shape:
# TODO only allow reshape when shape different by empty axis
assert np.prod(varshape) == np.prod(val.shape), \
"{}: {}!={}".format(name, varshape, val.shape)
logger.warn("Variable {} is reshaped {}->{} during assigning".format(
name, val.shape, varshape))
if np.prod(varshape) != np.prod(val.shape):
raise ValueError(
"Trying to load a tensor of shape {} into the variable '{}' whose shape is {}.".format(
val.shape, name, varshape))
logger.warn("The tensor is reshaped from {} to {} when assigned to '{}'".format(
val.shape, varshape, name))
val = val.reshape(varshape)
# fix some common type incompatibility problems, but not all
......
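
A hedged example of when this reshape-on-assign path triggers: the loaded tensor and the variable differ only by singleton axes, so the element counts match and the value is reshaped with a warning; any genuine shape mismatch now raises `ValueError` instead of tripping an assert:

```python
import numpy as np

varshape = (1, 1, 1, 1000)          # shape of the variable in the graph (illustrative)
val = np.zeros((1000,))             # tensor loaded from the checkpoint (illustrative)
assert np.prod(varshape) == np.prod(val.shape)
val = val.reshape(varshape)         # what SessionUpdate does after logging the warning
```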
......@@ -13,6 +13,9 @@ __all__ = ['change_gpu', 'get_nr_gpu', 'get_num_gpu']
def change_gpu(val):
"""
Args:
val: an integer, the index of the GPU or -1 to disable GPU.
Returns:
a context where ``CUDA_VISIBLE_DEVICES=val``.
"""
......
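
A small usage sketch of the documented `change_gpu` context manager; the import path is the usual tensorpack location and the exact string values are assumptions based on the docstring:

```python
import os
from tensorpack.utils.gpu import change_gpu

with change_gpu(0):                                   # restrict this block to GPU 0
    print(os.environ['CUDA_VISIBLE_DEVICES'])         # presumably '0'

with change_gpu(-1):                                  # -1 disables the GPU inside the block
    print(os.environ['CUDA_VISIBLE_DEVICES'])         # presumably '-1'
```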