Commit 141ab53c authored by Yuxin Wu

[MaskRCNN] use un-quantized anchors; use better postprocessing; use 1x schedule

parent ae7b0774
......@@ -54,11 +54,8 @@ Model:
4. Because of (3), BatchNorm statistics are supposed to be frozen during fine-tuning.
5. An alternative to freezing BatchNorm is to sync BatchNorm statistics across
GPUs (the `BACKBONE.NORM=SyncBN` option). This would require [my bugfix](https://github.com/tensorflow/tensorflow/pull/20360)
which is available since TF 1.10. On earlier versions, you can manually apply the patch.
For now the total batch size is at most 8, so this option does not improve the model much.
6. Another alternative to BatchNorm is GroupNorm (`BACKBONE.NORM=GN`) which has better performance.
GPUs (the `BACKBONE.NORM=SyncBN` option).
Another alternative to BatchNorm is GroupNorm (`BACKBONE.NORM=GN`) which has better performance.
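For illustration, a minimal sketch (a hypothetical snippet, not repo code) of picking these normalization options on the config object used throughout this example; in practice they are passed as `KEY=VALUE` overrides:

```python
from config import config as cfg

cfg.BACKBONE.NORM = "GN"   # or "SyncBN"; the default freezes BatchNorm (note 4)
cfg.FPN.NORM = "GN"        # use GroupNorm in the FPN heads as well
```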
Efficiency:
......@@ -74,14 +71,22 @@ Efficiency:
1. This implementation does not use specialized CUDA ops (e.g. AffineChannel, ROIAlign).
Therefore it might be slower than other highly-optimized implementations.
1. To reduce RAM usage on host: (1) make sure you're using the "spawn" method as
set in `train.py`; (2) reduce `buffer_size` or `NUM_WORKERS` in `data.py`
(which may negatively impact your throughput). The training needs <10G RAM if `NUM_WORKERS=0`.
1. Inference is unoptimized. Tensorpack is a training interface, so it
does not help you with optimized inference.
1. To reduce RAM usage on host: (1) make sure you're using the "spawn" method as
set in `train.py`; (2) reduce `buffer_size` or `NUM_WORKERS` in `data.py`
(which may negatively impact your throughput). The training needs <10G RAM if `NUM_WORKERS=0` (see the sketch after this list).
1. Inference is unoptimized. Tensorpack is a training interface, so it
does not help you with optimized inference. In fact, the current implementation
uses some slow numpy operations in inference (in `eval.py:_paste_mask`).
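A hedged sketch of the RAM-reduction override mentioned above (`DATA.NUM_WORKERS` is the config key read in `data.py`; `buffer_size` is a local variable there):

```python
from config import config as cfg

cfg.DATA.NUM_WORKERS = 0   # single-process loading: <10G host RAM, lower throughput
# with workers enabled, data.py derives buffer_size from NUM_WORKERS (e.g. x10 under horovod)
```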
Possible Future Enhancements:
1. Define a better interface to load different datasets.
......
......@@ -82,42 +82,38 @@ All models are fine-tuned from ImageNet pre-trained R50/R101 models in
[tensorpack model zoo](http://models.tensorpack.com/FasterRCNN/), unless otherwise noted.
All models are trained with 8 NVIDIA V100s, unless otherwise noted.
Performance in [Detectron](https://github.com/facebookresearch/Detectron/) can
be approximately reproduced.
| Backbone | mAP<br/>(box;mask) | Detectron mAP <sup>[1](#ft1)</sup><br/> (box;mask) | Time <br/>(on 8 V100s) | Configurations <br/> (click to expand) |
| - | - | - | - | - |
| R50-C4 | 33.5 | | 17h | <details><summary>super quick</summary>`MODE_MASK=False FRCNN.BATCH_PER_IM=64`<br/>`PREPROC.TRAIN_SHORT_EDGE_SIZE=600 PREPROC.MAX_SIZE=1024`<br/>`TRAIN.LR_SCHEDULE=[150000,230000,280000]` </details> |
| R50-C4 | 36.6 | 36.5 | 44h | <details><summary>standard</summary>`MODE_MASK=False` </details> |
| R50-FPN | 37.4 | 37.9 | 23h | <details><summary>standard</summary>`MODE_MASK=False MODE_FPN=True` </details> |
| R50-C4 | 38.2;33.3 [:arrow_down:][R50C42x] | 37.8;32.8 | 49h | <details><summary>standard</summary>this is the default </details> |
| R50-FPN | 38.4;35.1 [:arrow_down:][R50FPN2x] | 38.6;34.5 | 27h | <details><summary>standard</summary>`MODE_FPN=True` </details> |
| R50-FPN | 42.0;36.3 | | 36h | <details><summary>+Cascade</summary>`MODE_FPN=True FPN.CASCADE=True` </details> |
| R50-FPN | 39.5;35.2 | 39.5;34.4<sup>[2](#ft2)</sup> | 31h | <details><summary>+ConvGNHead</summary>`MODE_FPN=True`<br/>`FPN.FRCNN_HEAD_FUNC=fastrcnn_4conv1fc_gn_head` </details> |
| R50-FPN | 40.0;36.2 [:arrow_down:][R50FPN2xGN] | 40.3;35.7 | 33h | <details><summary>+GN</summary>`MODE_FPN=True`<br/>`FPN.NORM=GN BACKBONE.NORM=GN`<br/>`FPN.FRCNN_HEAD_FUNC=fastrcnn_4conv1fc_gn_head`<br/>`FPN.MRCNN_HEAD_FUNC=maskrcnn_up4conv_gn_head` </details> |
| R101-C4 | 41.4;35.2 [:arrow_down:][R101C42x] | | 60h | <details><summary>standard</summary>`BACKBONE.RESNET_NUM_BLOCKS=[3,4,23,3]` </details> |
| R101-FPN | 40.4;36.6 [:arrow_down:][R101FPN2x] | 40.9;36.4 | 37h | <details><summary>standard</summary>`MODE_FPN=True`<br/>`BACKBONE.RESNET_NUM_BLOCKS=[3,4,23,3]` </details> |
| R101-FPN | 46.5;40.1 [:arrow_down:][R101FPN3xCasAug] <sup>[3](#ft3)</sup> | | 73h | <details><summary>3x+Cascade+TrainAug</summary>`MODE_FPN=True FPN.CASCADE=True`<br/>`BACKBONE.RESNET_NUM_BLOCKS=[3,4,23,3]`<br/>`TEST.RESULT_SCORE_THRESH=1e-4`<br/>`PREPROC.TRAIN_SHORT_EDGE_SIZE=[640,800]`<br/>`TRAIN.LR_SCHEDULE=[420000,500000,540000]` </details> |
| R101-FPN<br/>(From Scratch) | 47.5;41.2 [:arrow_down:][R101FPN9xGNCasAugScratch] | 47.4;40.5<sup>[4](#ft4)</sup> | 45h <br/>(on 48 V100s) | <details><summary>9x+GN+Cascade+TrainAug</summary>`MODE_FPN=True FPN.CASCADE=True`<br/>`BACKBONE.RESNET_NUM_BLOCKS=[3,4,23,3]`<br/>`FPN.NORM=GN BACKBONE.NORM=GN`<br/>`FPN.FRCNN_HEAD_FUNC=fastrcnn_4conv1fc_gn_head`<br/>`FPN.MRCNN_HEAD_FUNC=maskrcnn_up4conv_gn_head`<br/>`PREPROC.TRAIN_SHORT_EDGE_SIZE=[640,800]`<br/>`TRAIN.LR_SCHEDULE=[1500000,1580000,1620000]`<br/>`BACKBONE.FREEZE_AT=0`</details> |
[R50C42x]: http://models.tensorpack.com/FasterRCNN/COCO-R50C4-MaskRCNN-Standard.npz
[R50FPN2x]: http://models.tensorpack.com/FasterRCNN/COCO-R50FPN-MaskRCNN-Standard.npz
[R50FPN2xGN]: http://models.tensorpack.com/FasterRCNN/COCO-R50FPN-MaskRCNN-StandardGN.npz
[R101C42x]: http://models.tensorpack.com/FasterRCNN/COCO-R101C4-MaskRCNN-Standard.npz
[R101FPN2x]: http://models.tensorpack.com/FasterRCNN/COCO-R101FPN-MaskRCNN-Standard.npz
[R101FPN3xCasAug]: http://models.tensorpack.com/FasterRCNN/COCO-R101FPN-MaskRCNN-BetterParams.npz
[R101FPN9xGNCasAugScratch]: http://models.tensorpack.com/FasterRCNN/COCO-R101FPN-MaskRCNN-ScratchGN.npz
Performance in [Detectron](https://github.com/facebookresearch/Detectron/) can be reproduced.
| Backbone | mAP<br/>(box;mask) | Detectron mAP <sup>[1](#ft1)</sup><br/> (box;mask) | Time <br/>(on 8 V100s) | Configurations <br/> (click to expand) |
| - | - | - | - | - |
| R50-C4 | 34.1 | | 7.5h | <details><summary>super quick</summary>`MODE_MASK=False FRCNN.BATCH_PER_IM=64`<br/>`PREPROC.TRAIN_SHORT_EDGE_SIZE=600 PREPROC.MAX_SIZE=1024`<br/>`TRAIN.LR_SCHEDULE=[140000,180000,200000]` </details> |
| R50-C4 | 35.6 | 34.8 | 23h | <details><summary>standard</summary>`MODE_MASK=False` </details> |
| R50-FPN | 37.5 | 36.7 | 11h | <details><summary>standard</summary>`MODE_MASK=False MODE_FPN=True` </details> |
| R50-C4 | 36.2;31.8 [:arrow_down:][R50C41x] | 35.8;31.4 | 23.5h | <details><summary>standard</summary>this is the default </details> |
| R50-FPN | 38.2;34.8 | 37.7;33.9 | 13.5h | <details><summary>standard</summary>`MODE_FPN=True` </details> |
| R50-FPN | 38.9;35.4 [:arrow_down:][R50FPN2x] | 38.6;34.5 | 25h | <details><summary>2x</summary>`MODE_FPN=True`<br/>`TRAIN.LR_SCHEDULE=[240000,320000,360000]` </details> |
| R50-FPN-GN | 40.4;36.3 [:arrow_down:][R50FPN2xGN] | 40.3;35.7 | 31h | <details><summary>2x+GN</summary>`MODE_FPN=True`<br/>`FPN.NORM=GN BACKBONE.NORM=GN`<br/>`FPN.FRCNN_HEAD_FUNC=fastrcnn_4conv1fc_gn_head`<br/>`FPN.MRCNN_HEAD_FUNC=maskrcnn_up4conv_gn_head` </details> |
| R50-FPN | 41.7;36.2 | | 17h | <details><summary>+Cascade</summary>`MODE_FPN=True FPN.CASCADE=True` </details> |
| R101-C4 | 40.1;34.6 [:arrow_down:][R101C41x] | | 28h | <details><summary>standard</summary>`BACKBONE.RESNET_NUM_BLOCKS=[3,4,23,3]` </details> |
| R101-FPN | 40.7;36.8 [:arrow_down:][R101FPN1x] | 40.0;35.9 | 18h | <details><summary>standard</summary>`MODE_FPN=True`<br/>`BACKBONE.RESNET_NUM_BLOCKS=[3,4,23,3]` </details> |
| R101-FPN | 46.6;40.3 [:arrow_down:][R101FPN3xCasAug] <sup>[2](#ft2)</sup> | | 69h | <details><summary>3x+Cascade+TrainAug</summary>`MODE_FPN=True FPN.CASCADE=True`<br/>`BACKBONE.RESNET_NUM_BLOCKS=[3,4,23,3]`<br/>`TEST.RESULT_SCORE_THRESH=1e-4`<br/>`PREPROC.TRAIN_SHORT_EDGE_SIZE=[640,800]`<br/>`TRAIN.LR_SCHEDULE=[420000,500000,540000]` </details> |
| R101-FPN-GN<br/>(From Scratch) | 47.7;41.7 [:arrow_down:][R101FPN9xGNCasAugScratch]<sup>[3](#ft3)</sup> | 47.4;40.5 | 28h (on 64 V100s) | <details><summary>9x+GN+Cascade+TrainAug</summary>`MODE_FPN=True FPN.CASCADE=True`<br/>`BACKBONE.RESNET_NUM_BLOCKS=[3,4,23,3]`<br/>`FPN.NORM=GN BACKBONE.NORM=GN`<br/>`FPN.FRCNN_HEAD_FUNC=fastrcnn_4conv1fc_gn_head`<br/>`FPN.MRCNN_HEAD_FUNC=maskrcnn_up4conv_gn_head`<br/>`PREPROC.TRAIN_SHORT_EDGE_SIZE=[640,800]`<br/>`TRAIN.LR_SCHEDULE=[1500000,1580000,1620000]`<br/>`BACKBONE.FREEZE_AT=0`</details> |
[R50C41x]: http://models.tensorpack.com/FasterRCNN/COCO-MaskRCNN-R50C41x.npz
[R50FPN2x]: http://models.tensorpack.com/FasterRCNN/COCO-MaskRCNN-R50FPN2x.npz
[R50FPN2xGN]: http://models.tensorpack.com/FasterRCNN/COCO-MaskRCNN-R50FPN2xGN.npz
[R101C41x]: http://models.tensorpack.com/FasterRCNN/COCO-MaskRCNN-R101C41x.npz
[R101FPN1x]: http://models.tensorpack.com/FasterRCNN/COCO-MaskRCNN-R101FPN1x.npz
[R101FPN3xCasAug]: http://models.tensorpack.com/FasterRCNN/COCO-MaskRCNN-R101FPN3xCasAug.npz
[R101FPN9xGNCasAugScratch]: http://models.tensorpack.com/FasterRCNN/COCO-MaskRCNN-R101FPN9xGNCasAugScratch.npz
<a id="ft1">1</a>: Numbers taken from [Detectron Model Zoo](https://github.com/facebookresearch/Detectron/blob/master/MODEL_ZOO.md).
We compare models that have identical training & inference cost between the two implementations. Their numbers can be different due to many small implementation details.
For example, our FPN models are sometimes slightly worse in box AP, which is
mainly due to batch size.
We compare models that have identical training & inference cost between the two implementations.
Their numbers can be different due to small implementation details.
<a id="ft2">2</a>: Numbers taken from Table 5 in [Group Normalization](https://arxiv.org/abs/1803.08494)
<a id="ft2">2</a>: Our mAP is __10+ point__ better than the official model in [matterport/Mask_RCNN](https://github.com/matterport/Mask_RCNN/releases/tag/v2.0) with the same R101-FPN backbone.
<a id="ft3">3</a>: Our mAP is __10+ point__ better than the official model in [matterport/Mask_RCNN](https://github.com/matterport/Mask_RCNN/releases/tag/v2.0) with the same R101-FPN backbone.
<a id="ft4">4</a>: This entry does not use ImageNet pre-training. Detectron numbers are taken from Fig. 5 in [Rethinking ImageNet Pre-training](https://arxiv.org/abs/1811.08883).
<a id="ft3">3</a>: This entry does not use ImageNet pre-training. Detectron numbers are taken from Fig. 5 in [Rethinking ImageNet Pre-training](https://arxiv.org/abs/1811.08883).
Note that our training strategy is slightly different: we enable cascade throughout the entire training.
As far as I know, this model is the __best open source model__ on the COCO dataset.
......
......@@ -130,8 +130,8 @@ _C.TRAIN.STARTING_EPOCH = 1 # the first epoch to start with, useful to continue
# the actual learning-rate schedule is computed from BASE_LR and LR_SCHEDULE.
# Therefore, there is *no need* to modify the config if you only change the number of GPUs.
# _C.TRAIN.LR_SCHEDULE = [120000, 160000, 180000] # "1x" schedule in detectron
_C.TRAIN.LR_SCHEDULE = [240000, 320000, 360000] # "2x" schedule in detectron
_C.TRAIN.LR_SCHEDULE = [120000, 160000, 180000] # "1x" schedule in detectron
# _C.TRAIN.LR_SCHEDULE = [240000, 320000, 360000] # "2x" schedule in detectron
# Longer schedules for from-scratch training (https://arxiv.org/abs/1811.08883):
# _C.TRAIN.LR_SCHEDULE = [960000, 1040000, 1080000] # "6x" schedule in detectron
# _C.TRAIN.LR_SCHEDULE = [1500000, 1580000, 1620000] # "9x" schedule in detectron
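# A sketch (an observation from the numbers above, not repo code): these
# from-scratch schedules are N x 180000 steps long, with the two LR drops a
# fixed 120k / 40k steps before the end:
#   def long_schedule(n):
#       total = n * 180000
#       return [total - 120000, total - 40000, total]
#   long_schedule(6) == [960000, 1040000, 1080000]   # "6x"
#   long_schedule(9) == [1500000, 1580000, 1620000]  # "9x"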
......
......@@ -9,17 +9,18 @@ from tabulate import tabulate
from termcolor import colored
from tensorpack.dataflow import (
DataFromList, MapData, MapDataComponent, MultiProcessMapData, MultiThreadMapData,
TestDataSpeed, imgaug)
DataFromList, MapData, MapDataComponent,
MultiProcessMapData, MultiThreadMapData, TestDataSpeed, imgaug,
)
from tensorpack.utils import logger
from tensorpack.utils.argtools import log_once, memoized
from common import (
CustomResize, DataFromListOfDict, box_to_point8, filter_boxes_inside_shape, np_iou,
point8_to_box, segmentation_to_mask)
CustomResize, DataFromListOfDict, box_to_point8,
filter_boxes_inside_shape, np_iou, point8_to_box, segmentation_to_mask,
)
from config import config as cfg
from dataset import DatasetRegistry
from utils.generate_anchors import generate_anchors
from utils.np_box_ops import area as np_area
from utils.np_box_ops import ioa as np_ioa
......@@ -42,15 +43,14 @@ def print_class_histogram(roidbs):
gt_hist = np.zeros((cfg.DATA.NUM_CATEGORY + 1,), dtype=np.int)
for entry in roidbs:
# filter crowd?
gt_inds = np.where(
(entry['class'] > 0) & (entry['is_crowd'] == 0))[0]
gt_classes = entry['class'][gt_inds]
gt_inds = np.where((entry["class"] > 0) & (entry["is_crowd"] == 0))[0]
gt_classes = entry["class"][gt_inds]
gt_hist += np.histogram(gt_classes, bins=hist_bins)[0]
data = [[cfg.DATA.CLASS_NAMES[i], v] for i, v in enumerate(gt_hist)]
data.append(['total', sum(x[1] for x in data)])
data.append(["total", sum(x[1] for x in data)])
# the first line is BG
table = tabulate(data[1:], headers=['class', '#box'], tablefmt='pipe')
logger.info("Ground-Truth Boxes:\n" + colored(table, 'cyan'))
table = tabulate(data[1:], headers=["class", "#box"], tablefmt="pipe")
logger.info("Ground-Truth Boxes:\n" + colored(table, "cyan"))
@memoized
......@@ -69,17 +69,17 @@ def get_all_anchors(*, stride, sizes, ratios, max_size):
"""
# Generates a NAx4 matrix of anchor boxes in (x1, y1, x2, y2) format. Anchors
# are centered on stride / 2, have (approximate) sqrt areas of the specified
# sizes, and aspect ratios as given.
cell_anchors = generate_anchors(
stride,
scales=np.array(sizes, dtype=np.float) / stride,
ratios=np.array(ratios, dtype=np.float))
# anchors are intbox here.
# anchors at featuremap [0,0] are centered at fpcoor (8,8) (half of stride)
# are centered on 0, have sqrt areas equal to the specified sizes, and aspect ratios as given.
anchors = []
for sz in sizes:
for ratio in ratios:
w = np.sqrt(sz * sz / ratio)
h = ratio * w
anchors.append([-w, -h, w, h])
cell_anchors = np.asarray(anchors) * 0.5
field_size = int(np.ceil(max_size / stride))
shifts = np.arange(0, field_size) * stride
shifts = (np.arange(0, field_size) * stride).astype("float32")
shift_x, shift_y = np.meshgrid(shifts, shifts)
shift_x = shift_x.flatten()
shift_y = shift_y.flatten()
......@@ -88,15 +88,12 @@ def get_all_anchors(*, stride, sizes, ratios, max_size):
K = shifts.shape[0]
A = cell_anchors.shape[0]
field_of_anchors = (
cell_anchors.reshape((1, A, 4)) +
shifts.reshape((1, K, 4)).transpose((1, 0, 2)))
field_of_anchors = cell_anchors.reshape((1, A, 4)) + shifts.reshape((1, K, 4)).transpose((1, 0, 2))
field_of_anchors = field_of_anchors.reshape((field_size, field_size, A, 4))
# FSxFSxAx4
# A lot of rounding happens inside the anchor code anyway
# assert np.all(field_of_anchors == field_of_anchors.astype('int32'))
field_of_anchors = field_of_anchors.astype('float32')
field_of_anchors[:, :, :, [2, 3]] += 1
field_of_anchors = field_of_anchors.astype("float32")
return field_of_anchors
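# Illustrative standalone sketch (hypothetical helper, not part of this
# module): the un-quantized cell anchors above are centered at the origin
# and keep exact areas, e.g. a 32x32 anchor at ratio 1 is [-16, -16, 16, 16].
def _demo_cell_anchors(sizes=(32,), ratios=(0.5, 1.0, 2.0)):
    anchors = []
    for sz in sizes:
        for ratio in ratios:
            w = np.sqrt(sz * sz / ratio)   # w * h == sz ** 2, no rounding
            h = ratio * w
            anchors.append([-w, -h, w, h])
    return np.asarray(anchors) * 0.5       # (x1, y1, x2, y2) around (0, 0)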
......@@ -121,18 +118,19 @@ class TrainingDataPreprocessor:
Since the mapping may run in other processes, we define a new class and
explicitly pass cfg to it, in the spirit of "explicitly pass resources to subprocesses".
"""
def __init__(self, cfg):
self.cfg = cfg
self.aug = imgaug.AugmentorList(
[CustomResize(cfg.PREPROC.TRAIN_SHORT_EDGE_SIZE, cfg.PREPROC.MAX_SIZE),
imgaug.Flip(horiz=True)])
[CustomResize(cfg.PREPROC.TRAIN_SHORT_EDGE_SIZE, cfg.PREPROC.MAX_SIZE), imgaug.Flip(horiz=True)]
)
def __call__(self, roidb):
fname, boxes, klass, is_crowd = roidb['file_name'], roidb['boxes'], roidb['class'], roidb['is_crowd']
fname, boxes, klass, is_crowd = roidb["file_name"], roidb["boxes"], roidb["class"], roidb["is_crowd"]
boxes = np.copy(boxes)
im = cv2.imread(fname, cv2.IMREAD_COLOR)
assert im is not None, fname
im = im.astype('float32')
im = im.astype("float32")
height, width = im.shape[:2]
# assume floatbox as input
assert boxes.dtype == np.float32, "Loader has to return floating point boxes!"
......@@ -148,30 +146,30 @@ class TrainingDataPreprocessor:
boxes = point8_to_box(points)
assert np.min(np_area(boxes)) > 0, "Some boxes have zero area!"
ret = {'image': im}
ret = {"image": im}
# Add rpn data to dataflow:
try:
if self.cfg.MODE_FPN:
multilevel_anchor_inputs = self.get_multilevel_rpn_anchor_input(im, boxes, is_crowd)
for i, (anchor_labels, anchor_boxes) in enumerate(multilevel_anchor_inputs):
ret['anchor_labels_lvl{}'.format(i + 2)] = anchor_labels
ret['anchor_boxes_lvl{}'.format(i + 2)] = anchor_boxes
ret["anchor_labels_lvl{}".format(i + 2)] = anchor_labels
ret["anchor_boxes_lvl{}".format(i + 2)] = anchor_boxes
else:
ret['anchor_labels'], ret['anchor_boxes'] = self.get_rpn_anchor_input(im, boxes, is_crowd)
ret["anchor_labels"], ret["anchor_boxes"] = self.get_rpn_anchor_input(im, boxes, is_crowd)
boxes = boxes[is_crowd == 0] # skip crowd boxes in training target
boxes = boxes[is_crowd == 0] # skip crowd boxes in training target
klass = klass[is_crowd == 0]
ret['gt_boxes'] = boxes
ret['gt_labels'] = klass
ret["gt_boxes"] = boxes
ret["gt_labels"] = klass
if not len(boxes):
raise MalformedData("No valid gt_boxes!")
except MalformedData as e:
log_once("Input {} is filtered for training: {}".format(fname, str(e)), 'warn')
log_once("Input {} is filtered for training: {}".format(fname, str(e)), "warn")
return None
if self.cfg.MODE_MASK:
# augmentation will modify the polys in-place
segmentation = copy.deepcopy(roidb['segmentation'])
segmentation = copy.deepcopy(roidb["segmentation"])
segmentation = [segmentation[k] for k in range(len(segmentation)) if not is_crowd[k]]
assert len(segmentation) == len(boxes)
......@@ -210,11 +208,14 @@ class TrainingDataPreprocessor:
NA will be NUM_ANCHOR_SIZES x NUM_ANCHOR_RATIOS
"""
boxes = boxes.copy()
all_anchors = np.copy(get_all_anchors(
stride=self.cfg.RPN.ANCHOR_STRIDE,
sizes=self.cfg.RPN.ANCHOR_SIZES,
ratios=self.cfg.RPN.ANCHOR_RATIOS,
max_size=self.cfg.PREPROC.MAX_SIZE))
all_anchors = np.copy(
get_all_anchors(
stride=self.cfg.RPN.ANCHOR_STRIDE,
sizes=self.cfg.RPN.ANCHOR_SIZES,
ratios=self.cfg.RPN.ANCHOR_RATIOS,
max_size=self.cfg.PREPROC.MAX_SIZE,
)
)
# fHxfWxAx4 -> (-1, 4)
featuremap_anchors_flatten = all_anchors.reshape((-1, 4))
......@@ -222,15 +223,16 @@ class TrainingDataPreprocessor:
inside_ind, inside_anchors = filter_boxes_inside_shape(featuremap_anchors_flatten, im.shape[:2])
# obtain anchor labels and their corresponding gt boxes
anchor_labels, anchor_gt_boxes = self.get_anchor_labels(
inside_anchors, boxes[is_crowd == 0], boxes[is_crowd == 1])
inside_anchors, boxes[is_crowd == 0], boxes[is_crowd == 1]
)
# Fill them back to original size: fHxfWx1, fHxfWx4
num_anchor = self.cfg.RPN.NUM_ANCHOR
anchorH, anchorW = all_anchors.shape[:2]
featuremap_labels = -np.ones((anchorH * anchorW * num_anchor, ), dtype='int32')
featuremap_labels = -np.ones((anchorH * anchorW * num_anchor,), dtype="int32")
featuremap_labels[inside_ind] = anchor_labels
featuremap_labels = featuremap_labels.reshape((anchorH, anchorW, num_anchor))
featuremap_boxes = np.zeros((anchorH * anchorW * num_anchor, 4), dtype='float32')
featuremap_boxes = np.zeros((anchorH * anchorW * num_anchor, 4), dtype="float32")
featuremap_boxes[inside_ind, :] = anchor_gt_boxes
featuremap_boxes = featuremap_boxes.reshape((anchorH, anchorW, num_anchor, 4))
return featuremap_labels, featuremap_boxes
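# Minimal sketch (hypothetical helper) of the scatter-back pattern above:
# labels are computed only for anchors inside the image, then written into
# a full-size array initialized to -1 (i.e. ignored).
def _demo_scatter_back():
    inside_ind = np.array([1, 3, 4])
    anchor_labels = np.array([1, 0, 1], dtype="int32")
    featuremap_labels = -np.ones((6,), dtype="int32")
    featuremap_labels[inside_ind] = anchor_labels
    return featuremap_labels   # array([-1, 1, -1, 0, 1, -1], dtype=int32)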
......@@ -254,32 +256,33 @@ class TrainingDataPreprocessor:
strides=self.cfg.FPN.ANCHOR_STRIDES,
sizes=self.cfg.RPN.ANCHOR_SIZES,
ratios=self.cfg.RPN.ANCHOR_RATIOS,
max_size=self.cfg.PREPROC.MAX_SIZE)
max_size=self.cfg.PREPROC.MAX_SIZE,
)
flatten_anchors_per_level = [k.reshape((-1, 4)) for k in anchors_per_level]
all_anchors_flatten = np.concatenate(flatten_anchors_per_level, axis=0)
inside_ind, inside_anchors = filter_boxes_inside_shape(all_anchors_flatten, im.shape[:2])
anchor_labels, anchor_gt_boxes = self.get_anchor_labels(
inside_anchors, boxes[is_crowd == 0], boxes[is_crowd == 1])
inside_anchors, boxes[is_crowd == 0], boxes[is_crowd == 1]
)
# map back to all_anchors, then split to each level
num_all_anchors = all_anchors_flatten.shape[0]
all_labels = -np.ones((num_all_anchors, ), dtype='int32')
all_labels = -np.ones((num_all_anchors,), dtype="int32")
all_labels[inside_ind] = anchor_labels
all_boxes = np.zeros((num_all_anchors, 4), dtype='float32')
all_boxes = np.zeros((num_all_anchors, 4), dtype="float32")
all_boxes[inside_ind] = anchor_gt_boxes
start = 0
multilevel_inputs = []
for level_anchor in anchors_per_level:
assert level_anchor.shape[2] == len(self.cfg.RPN.ANCHOR_RATIOS)
anchor_shape = level_anchor.shape[:3] # fHxfWxNUM_ANCHOR_RATIOS
anchor_shape = level_anchor.shape[:3] # fHxfWxNUM_ANCHOR_RATIOS
num_anchor_this_level = np.prod(anchor_shape)
end = start + num_anchor_this_level
multilevel_inputs.append(
(all_labels[start: end].reshape(anchor_shape),
all_boxes[start: end, :].reshape(anchor_shape + (4,))
))
(all_labels[start:end].reshape(anchor_shape), all_boxes[start:end, :].reshape(anchor_shape + (4,)))
)
start = end
assert end == num_all_anchors, "{} != {}".format(end, num_all_anchors)
return multilevel_inputs
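# Minimal sketch (hypothetical helper) of the per-level split above:
# consecutive slices of the flat array, reshaped back to each FPN level's
# (fH, fW, num_anchors) grid.
def _demo_split_levels(flat_labels, level_shapes):
    start, per_level = 0, []
    for shape in level_shapes:
        end = start + int(np.prod(shape))
        per_level.append(flat_labels[start:end].reshape(shape))
        start = end
    assert start == len(flat_labels)
    return per_level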
......@@ -300,10 +303,8 @@ class TrainingDataPreprocessor:
def filter_box_label(labels, value, max_num):
curr_inds = np.where(labels == value)[0]
if len(curr_inds) > max_num:
disable_inds = np.random.choice(
curr_inds, size=(len(curr_inds) - max_num),
replace=False)
labels[disable_inds] = -1 # ignore them
disable_inds = np.random.choice(curr_inds, size=(len(curr_inds) - max_num), replace=False)
labels[disable_inds] = -1 # ignore them
curr_inds = np.where(labels == value)[0]
return curr_inds
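# Standalone sketch of the same subsampling (the helper above is nested in
# get_anchor_labels): cap positives at 2 by disabling a random surplus.
def _demo_subsample():
    labels = np.array([1, 1, 1, 1, 0, -1], dtype="int32")
    pos = np.where(labels == 1)[0]
    drop = np.random.choice(pos, size=len(pos) - 2, replace=False)
    labels[drop] = -1              # ignored, just like above
    return labels                  # exactly two 1s remain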
......@@ -317,7 +318,7 @@ class TrainingDataPreprocessor:
anchors_with_max_iou_per_gt = np.where(box_ious == ious_max_per_gt)[0]
# Setting NA labels: 1--fg 0--bg -1--ignore
anchor_labels = -np.ones((NA,), dtype='int32') # NA,
anchor_labels = -np.ones((NA,), dtype="int32") # NA,
# the order of setting neg/pos labels matters
anchor_labels[anchors_with_max_iou_per_gt] = 1
......@@ -345,10 +346,10 @@ class TrainingDataPreprocessor:
# No valid bg in this image, skip.
raise MalformedData("No valid background for RPN!")
target_num_bg = self.cfg.RPN.BATCH_PER_IM - len(fg_inds)
filter_box_label(anchor_labels, 0, target_num_bg) # ignore return values
filter_box_label(anchor_labels, 0, target_num_bg) # ignore return values
# Set anchor boxes: the best gt_box for each fg anchor
anchor_boxes = np.zeros((NA, 4), dtype='float32')
anchor_boxes = np.zeros((NA, 4), dtype="float32")
fg_boxes = gt_boxes[ious_argmax_per_anchor[fg_inds], :]
anchor_boxes[fg_inds, :] = fg_boxes
# assert len(fg_inds) + np.sum(anchor_labels == 0) == self.cfg.RPN.BATCH_PER_IM
......@@ -377,16 +378,19 @@ def get_train_dataflow():
# Valid training images should have at least one fg box.
# But this filter must not be applied to the test set.
num = len(roidbs)
roidbs = list(filter(lambda img: len(img['boxes'][img['is_crowd'] == 0]) > 0, roidbs))
logger.info("Filtered {} images which contain no non-crowd groudtruth boxes. Total #images for training: {}".format(
num - len(roidbs), len(roidbs)))
roidbs = list(filter(lambda img: len(img["boxes"][img["is_crowd"] == 0]) > 0, roidbs))
logger.info(
"Filtered {} images which contain no non-crowd groudtruth boxes. Total #images for training: {}".format(
num - len(roidbs), len(roidbs)
)
)
ds = DataFromList(roidbs, shuffle=True)
preprocess = TrainingDataPreprocessor(cfg)
if cfg.DATA.NUM_WORKERS > 0:
if cfg.TRAINER == 'horovod':
if cfg.TRAINER == "horovod":
buffer_size = cfg.DATA.NUM_WORKERS * 10  # one dataflow per process, so no large buffer is needed
ds = MultiThreadMapData(ds, cfg.DATA.NUM_WORKERS, preprocess, buffer_size=buffer_size)
# MPI does not like fork()
......@@ -412,21 +416,23 @@ def get_eval_dataflow(name, shard=0, num_shards=1):
img_range = (shard * img_per_shard, (shard + 1) * img_per_shard if shard + 1 < num_shards else num_imgs)
# unlike training, no filtering here
ds = DataFromListOfDict(roidbs[img_range[0]: img_range[1]], ['file_name', 'image_id'])
ds = DataFromListOfDict(roidbs[img_range[0]: img_range[1]], ["file_name", "image_id"])
def f(fname):
im = cv2.imread(fname, cv2.IMREAD_COLOR)
assert im is not None, fname
return im
ds = MapDataComponent(ds, f, 0)
# Evaluation itself may be multi-threaded, therefore don't add prefetch here.
return ds
if __name__ == '__main__':
if __name__ == "__main__":
import os
from tensorpack.dataflow import PrintData
cfg.DATA.BASEDIR = os.path.expanduser('~/data/coco')
cfg.DATA.BASEDIR = os.path.expanduser("~/data/coco")
ds = get_train_dataflow()
ds = PrintData(ds, 100)
TestDataSpeed(ds, 50000).start()
......
......@@ -13,6 +13,7 @@ from contextlib import ExitStack
import cv2
import pycocotools.mask as cocomask
import tqdm
from scipy import interpolate
from tensorpack.callbacks import Callback
from tensorpack.tfutils.common import get_tf_version_tuple
......@@ -41,6 +42,23 @@ mask: None, or a binary image of the original image shape
"""
def _scale_box(box, scale):
w_half = (box[2] - box[0]) * 0.5
h_half = (box[3] - box[1]) * 0.5
x_c = (box[2] + box[0]) * 0.5
y_c = (box[3] + box[1]) * 0.5
w_half *= scale
h_half *= scale
scaled_box = np.zeros_like(box)
scaled_box[0] = x_c - w_half
scaled_box[2] = x_c + w_half
scaled_box[1] = y_c - h_half
scaled_box[3] = y_c + h_half
return scaled_box
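# For example, scaling a 10x10 box about its center by 1.2:
#   _scale_box(np.asarray([0., 0., 10., 10.]), 1.2) -> [-1., -1., 11., 11.]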
def _paste_mask(box, mask, shape):
"""
Args:
......@@ -50,23 +68,42 @@ def _paste_mask(box, mask, shape):
Returns:
A uint8 binary image of hxw.
"""
# int() is floor
# box fpcoor=0.0 -> intcoor=0.0
x0, y0 = list(map(int, box[:2] + 0.5))
# box fpcoor=h -> intcoor=h-1, inclusive
x1, y1 = list(map(int, box[2:] - 0.5)) # inclusive
x1 = max(x0, x1) # require at least 1x1
y1 = max(y0, y1)
assert mask.shape[0] == mask.shape[1], mask.shape
if True:
# This method is accurate but much slower.
mask = np.pad(mask, [(1, 1), (1, 1)], mode='constant')
box = _scale_box(box, float(mask.shape[0]) / (mask.shape[0] - 2))
mask_pixels = np.arange(0.0, mask.shape[0]) + 0.5
mask_continuous = interpolate.interp2d(mask_pixels, mask_pixels, mask, fill_value=0.0)
h, w = shape
ys = np.arange(0.0, h) + 0.5
xs = np.arange(0.0, w) + 0.5
ys = (ys - box[1]) / (box[3] - box[1]) * mask.shape[0]
xs = (xs - box[0]) / (box[2] - box[0]) * mask.shape[1]
res = mask_continuous(xs, ys)
return (res >= 0.5).astype('uint8')
else:
# This method (inspired by Detectron) is less accurate but fast.
# int() truncates towards zero, which equals floor for these non-negative coords
# box fpcoor=0.0 -> intcoor=0.0
x0, y0 = list(map(int, box[:2] + 0.5))
# box fpcoor=h -> intcoor=h-1, inclusive
x1, y1 = list(map(int, box[2:] - 0.5)) # inclusive
x1 = max(x0, x1) # require at least 1x1
y1 = max(y0, y1)
w = x1 + 1 - x0
h = y1 + 1 - y0
w = x1 + 1 - x0
h = y1 + 1 - y0
# rounding errors could happen here, because masks were not originally computed for this shape.
# but it's hard to do better, because the network does not know the "original" scale
mask = (cv2.resize(mask, (w, h)) > 0.5).astype('uint8')
ret = np.zeros(shape, dtype='uint8')
ret[y0:y1 + 1, x0:x1 + 1] = mask
return ret
# rounding errors could happen here, because masks were not originally computed for this shape.
# but it's hard to do better, because the network does not know the "original" scale
mask = (cv2.resize(mask, (w, h)) > 0.5).astype('uint8')
ret = np.zeros(shape, dtype='uint8')
ret[y0:y1 + 1, x0:x1 + 1] = mask
return ret
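# Hypothetical usage sketch: paste a 28x28 soft mask (e.g. the mask-head
# output) predicted for `box` back onto a 480x640 canvas:
#   soft = np.random.rand(28, 28).astype("float32")
#   full = _paste_mask(np.asarray([10., 20., 74., 84.]), soft, (480, 640))
#   # full.shape == (480, 640), full.dtype == uint8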
def predict_image(img, model_func):
......@@ -82,7 +119,6 @@ def predict_image(img, model_func):
Returns:
[DetectionResult]
"""
orig_shape = img.shape[:2]
resizer = CustomResize(cfg.PREPROC.TEST_SHORT_EDGE_SIZE, cfg.PREPROC.MAX_SIZE)
resized_img = resizer.augment(img)
......
# Some third-party helper functions
+ generate_anchors.py: copied from [py-faster-rcnn](https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/rpn/generate_anchors.py).
+ box_ops.py: modified from [TF object detection API](https://github.com/tensorflow/models/blob/master/research/object_detection/core/box_list_ops.py).
+ np_box_ops.py: copied from [TF object detection API](https://github.com/tensorflow/models/blob/master/research/object_detection/utils/np_box_ops.py).
# https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/rpn/generate_anchors.py
# --------------------------------------------------------
# Faster R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick and Sean Bell
# --------------------------------------------------------
import numpy as np
from six.moves import range
# Verify that we compute the same anchors as Shaoqing's matlab implementation:
#
# >> load output/rpn_cachedir/faster_rcnn_VOC2007_ZF_stage1_rpn/anchors.mat
# >> anchors
#
# anchors =
#
# -83 -39 100 56
# -175 -87 192 104
# -359 -183 376 200
# -55 -55 72 72
# -119 -119 136 136
# -247 -247 264 264
# -35 -79 52 96
# -79 -167 96 184
# -167 -343 184 360
# array([[ -83., -39., 100., 56.],
# [-175., -87., 192., 104.],
# [-359., -183., 376., 200.],
# [ -55., -55., 72., 72.],
# [-119., -119., 136., 136.],
# [-247., -247., 264., 264.],
# [ -35., -79., 52., 96.],
# [ -79., -167., 96., 184.],
# [-167., -343., 184., 360.]])
def generate_anchors(base_size=16, ratios=[0.5, 1, 2],
scales=2**np.arange(3, 6)):
"""
Generate anchor (reference) windows by enumerating aspect ratios X
scales wrt a reference (0, 0, 15, 15) window.
"""
base_anchor = np.array([1, 1, base_size, base_size], dtype='float32') - 1
ratio_anchors = _ratio_enum(base_anchor, ratios)
anchors = np.vstack([_scale_enum(ratio_anchors[i, :], scales)
for i in range(ratio_anchors.shape[0])])
return anchors
def _whctrs(anchor):
"""
Return width, height, x center, and y center for an anchor (window).
"""
w = anchor[2] - anchor[0] + 1
h = anchor[3] - anchor[1] + 1
x_ctr = anchor[0] + 0.5 * (w - 1)
y_ctr = anchor[1] + 0.5 * (h - 1)
return w, h, x_ctr, y_ctr
def _mkanchors(ws, hs, x_ctr, y_ctr):
"""
Given a vector of widths (ws) and heights (hs) around a center
(x_ctr, y_ctr), output a set of anchors (windows).
"""
ws = ws[:, np.newaxis]
hs = hs[:, np.newaxis]
anchors = np.hstack((x_ctr - 0.5 * (ws - 1),
y_ctr - 0.5 * (hs - 1),
x_ctr + 0.5 * (ws - 1),
y_ctr + 0.5 * (hs - 1)))
return anchors
def _ratio_enum(anchor, ratios):
"""
Enumerate a set of anchors for each aspect ratio wrt an anchor.
"""
w, h, x_ctr, y_ctr = _whctrs(anchor)
size = w * h
size_ratios = size / ratios
ws = np.round(np.sqrt(size_ratios))
hs = np.round(ws * ratios)
anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
return anchors
def _scale_enum(anchor, scales):
"""
Enumerate a set of anchors for each scale wrt an anchor.
"""
w, h, x_ctr, y_ctr = _whctrs(anchor)
ws = w * scales
hs = h * scales
anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
return anchors
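if __name__ == '__main__':
    # Quick inspection (a sketch): print the default quantized anchors and
    # compare with the reference table in the comment at the top of this file.
    print(generate_anchors())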