Commit 141ab53c authored by Yuxin Wu

[MaskRCNN] use un-quantized anchors; use better postprocessing; use 1x schedule

parent ae7b0774
@@ -54,11 +54,8 @@ Model:

4. Because of (3), BatchNorm statistics are supposed to be frozen during fine-tuning.

5. An alternative to freezing BatchNorm is to sync BatchNorm statistics across
   GPUs (the `BACKBONE.NORM=SyncBN` option).
   Another alternative to BatchNorm is GroupNorm (`BACKBONE.NORM=GN`), which has better performance.

Efficiency:
@@ -74,14 +71,22 @@ Efficiency:

1. This implementation does not use specialized CUDA ops (e.g. AffineChannel, ROIAlign).
   Therefore it might be slower than other highly-optimized implementations.

1. To reduce RAM usage on host: (1) make sure you're using the "spawn" method as
   set in `train.py`; (2) reduce `buffer_size` or `NUM_WORKERS` in `data.py`
   (which may negatively impact your throughput). The training needs <10G RAM if `NUM_WORKERS=0`.

1. Inference is unoptimized. Tensorpack is a training interface, therefore it
   does not help you on optimized inference. In fact, the current implementation
   uses some slow numpy operations in inference (in `eval.py:_paste_mask`).
Possible Future Enhancements:

1. Define a better interface to load different datasets.
@@ -82,42 +82,38 @@ All models are fine-tuned from ImageNet pre-trained R50/R101 models in
[tensorpack model zoo](http://models.tensorpack.com/FasterRCNN/), unless otherwise noted.
All models are trained with 8 NVIDIA V100s, unless otherwise noted.

Performance in [Detectron](https://github.com/facebookresearch/Detectron/) can be reproduced.
| Backbone | mAP<br/>(box;mask) | Detectron mAP <sup>[1](#ft1)</sup><br/> (box;mask) | Time <br/>(on 8 V100s) | Configurations <br/> (click to expand) |
| - | - | - | - | - |
| R50-C4 | 34.1 | | 7.5h | <details><summary>super quick</summary>`MODE_MASK=False FRCNN.BATCH_PER_IM=64`<br/>`PREPROC.TRAIN_SHORT_EDGE_SIZE=600 PREPROC.MAX_SIZE=1024`<br/>`TRAIN.LR_SCHEDULE=[140000,180000,200000]` </details> |
| R50-C4 | 35.6 | 34.8 | 23h | <details><summary>standard</summary>`MODE_MASK=False` </details> |
| R50-FPN | 37.5 | 36.7 | 11h | <details><summary>standard</summary>`MODE_MASK=False MODE_FPN=True` </details> |
| R50-C4 | 36.2;31.8 [:arrow_down:][R50C41x] | 35.8;31.4 | 23.5h | <details><summary>standard</summary>this is the default </details> |
| R50-FPN | 38.2;34.8 | 37.7;33.9 | 13.5h | <details><summary>standard</summary>`MODE_FPN=True` </details> |
| R50-FPN | 38.9;35.4 [:arrow_down:][R50FPN2x] | 38.6;34.5 | 25h | <details><summary>2x</summary>`MODE_FPN=True`<br/>`TRAIN.LR_SCHEDULE=[240000,320000,360000]` </details> |
| R50-FPN-GN | 40.4;36.3 [:arrow_down:][R50FPN2xGN] | 40.3;35.7 | 31h | <details><summary>2x+GN</summary>`MODE_FPN=True`<br/>`FPN.NORM=GN BACKBONE.NORM=GN`<br/>`FPN.FRCNN_HEAD_FUNC=fastrcnn_4conv1fc_gn_head`<br/>`FPN.MRCNN_HEAD_FUNC=maskrcnn_up4conv_gn_head` </details> |
| R50-FPN | 41.7;36.2 | | 17h | <details><summary>+Cascade</summary>`MODE_FPN=True FPN.CASCADE=True` </details> |
| R101-C4 | 40.1;34.6 [:arrow_down:][R101C41x] | | 28h | <details><summary>standard</summary>`BACKBONE.RESNET_NUM_BLOCKS=[3,4,23,3]` </details> |
| R101-FPN | 40.7;36.8 [:arrow_down:][R101FPN1x] | 40.0;35.9 | 18h | <details><summary>standard</summary>`MODE_FPN=True`<br/>`BACKBONE.RESNET_NUM_BLOCKS=[3,4,23,3]` </details> |
| R101-FPN | 46.6;40.3 [:arrow_down:][R101FPN3xCasAug] <sup>[2](#ft2)</sup> | | 69h | <details><summary>3x+Cascade+TrainAug</summary>`MODE_FPN=True FPN.CASCADE=True`<br/>`BACKBONE.RESNET_NUM_BLOCKS=[3,4,23,3]`<br/>`TEST.RESULT_SCORE_THRESH=1e-4`<br/>`PREPROC.TRAIN_SHORT_EDGE_SIZE=[640,800]`<br/>`TRAIN.LR_SCHEDULE=[420000,500000,540000]` </details> |
| R101-FPN-GN<br/>(From Scratch) | 47.7;41.7 [:arrow_down:][R101FPN9xGNCasAugScratch]<sup>[3](#ft3)</sup> | 47.4;40.5 | 28h (on 64 V100s) | <details><summary>9x+GN+Cascade+TrainAug</summary>`MODE_FPN=True FPN.CASCADE=True`<br/>`BACKBONE.RESNET_NUM_BLOCKS=[3,4,23,3]`<br/>`FPN.NORM=GN BACKBONE.NORM=GN`<br/>`FPN.FRCNN_HEAD_FUNC=fastrcnn_4conv1fc_gn_head`<br/>`FPN.MRCNN_HEAD_FUNC=maskrcnn_up4conv_gn_head`<br/>`PREPROC.TRAIN_SHORT_EDGE_SIZE=[640,800]`<br/>`TRAIN.LR_SCHEDULE=[1500000,1580000,1620000]`<br/>`BACKBONE.FREEZE_AT=0`</details> |
[R50C41x]: http://models.tensorpack.com/FasterRCNN/COCO-MaskRCNN-R50C41x.npz
[R50FPN2x]: http://models.tensorpack.com/FasterRCNN/COCO-MaskRCNN-R50FPN2x.npz
[R50FPN2xGN]: http://models.tensorpack.com/FasterRCNN/COCO-MaskRCNN-R50FPN2xGN.npz
[R101C41x]: http://models.tensorpack.com/FasterRCNN/COCO-MaskRCNN-R101C41x.npz
[R101FPN1x]: http://models.tensorpack.com/FasterRCNN/COCO-MaskRCNN-R101FPN1x.npz
[R101FPN3xCasAug]: http://models.tensorpack.com/FasterRCNN/COCO-MaskRCNN-R101FPN3xCasAug.npz
[R101FPN9xGNCasAugScratch]: http://models.tensorpack.com/FasterRCNN/COCO-MaskRCNN-R101FPN9xGNCasAugScratch.npz
<a id="ft1">1</a>: Numbers taken from [Detectron Model Zoo](https://github.com/facebookresearch/Detectron/blob/master/MODEL_ZOO.md). <a id="ft1">1</a>: Numbers taken from [Detectron Model Zoo](https://github.com/facebookresearch/Detectron/blob/master/MODEL_ZOO.md).
We compare models that have identical training & inference cost between the two implementations. Their numbers can be different due to many small implementation details. We compare models that have identical training & inference cost between the two implementations.
For example, our FPN models are sometimes slightly worse in box AP, which is Their numbers can be different due to small implementation details.
mainly due to batch size.
<a id="ft2">2</a>: Numbers taken from Table 5 in [Group Normalization](https://arxiv.org/abs/1803.08494) <a id="ft2">2</a>: Our mAP is __10+ point__ better than the official model in [matterport/Mask_RCNN](https://github.com/matterport/Mask_RCNN/releases/tag/v2.0) with the same R101-FPN backbone.
<a id="ft3">3</a>: Our mAP is __10+ point__ better than the official model in [matterport/Mask_RCNN](https://github.com/matterport/Mask_RCNN/releases/tag/v2.0) with the same R101-FPN backbone. <a id="ft3">3</a>: This entry does not use ImageNet pre-training. Detectron numbers are taken from Fig. 5 in [Rethinking ImageNet Pre-training](https://arxiv.org/abs/1811.08883).
<a id="ft4">4</a>: This entry does not use ImageNet pre-training. Detectron numbers are taken from Fig. 5 in [Rethinking ImageNet Pre-training](https://arxiv.org/abs/1811.08883).
Note that our training strategy is slightly different: we enable cascade throughout the entire training. Note that our training strategy is slightly different: we enable cascade throughout the entire training.
As far as I know, this model is the __best open source model__ on COCO dataset. As far as I know, this model is the __best open source model__ on COCO dataset.
@@ -130,8 +130,8 @@ _C.TRAIN.STARTING_EPOCH = 1  # the first epoch to start with, useful to continue

# the base learning rate is computed from BASE_LR and LR_SCHEDULE.
# Therefore, there is *no need* to modify the config if you only change the number of GPUs.
_C.TRAIN.LR_SCHEDULE = [120000, 160000, 180000]      # "1x" schedule in detectron
# _C.TRAIN.LR_SCHEDULE = [240000, 320000, 360000]    # "2x" schedule in detectron
# Longer schedules for from-scratch training (https://arxiv.org/abs/1811.08883):
# _C.TRAIN.LR_SCHEDULE = [960000, 1040000, 1080000]  # "6x" schedule in detectron
# _C.TRAIN.LR_SCHEDULE = [1500000, 1580000, 1620000] # "9x" schedule in detectron
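For intuition, here is a minimal sketch (an editor's illustration, not code from this commit) of how such a step-based schedule behaves, assuming the Detectron convention that the learning rate decays 10x at each listed boundary and the final entry only marks the total number of training steps:

```python
def lr_at_step(step, base_lr=0.01, schedule=(120000, 160000, 180000)):
    """Piecewise-constant LR: decay 10x at every boundary except the last,
    which marks the end of training (assumed Detectron semantics)."""
    lr = base_lr
    for boundary in schedule[:-1]:
        if step >= boundary:
            lr *= 0.1
    return lr

for step in (0, 130000, 170000):
    print(step, lr_at_step(step))   # 0.01, 0.001, 0.0001 (up to float rounding)
```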
@@ -9,17 +9,18 @@ from tabulate import tabulate
from termcolor import colored

from tensorpack.dataflow import (
    DataFromList, MapData, MapDataComponent,
    MultiProcessMapData, MultiThreadMapData, TestDataSpeed, imgaug,
)
from tensorpack.utils import logger
from tensorpack.utils.argtools import log_once, memoized

from common import (
    CustomResize, DataFromListOfDict, box_to_point8,
    filter_boxes_inside_shape, np_iou, point8_to_box, segmentation_to_mask,
)
from config import config as cfg
from dataset import DatasetRegistry
from utils.np_box_ops import area as np_area
from utils.np_box_ops import ioa as np_ioa
@@ -42,15 +43,14 @@ def print_class_histogram(roidbs):

    gt_hist = np.zeros((cfg.DATA.NUM_CATEGORY + 1,), dtype=np.int)
    for entry in roidbs:
        # filter crowd?
        gt_inds = np.where((entry["class"] > 0) & (entry["is_crowd"] == 0))[0]
        gt_classes = entry["class"][gt_inds]
        gt_hist += np.histogram(gt_classes, bins=hist_bins)[0]
    data = [[cfg.DATA.CLASS_NAMES[i], v] for i, v in enumerate(gt_hist)]
    data.append(["total", sum(x[1] for x in data)])
    # the first line is BG
    table = tabulate(data[1:], headers=["class", "#box"], tablefmt="pipe")
    logger.info("Ground-Truth Boxes:\n" + colored(table, "cyan"))
@memoized
@@ -69,17 +69,17 @@ def get_all_anchors(*, stride, sizes, ratios, max_size):
    """
    # Generates a NAx4 matrix of anchor boxes in (x1, y1, x2, y2) format. Anchors
    # are centered on 0, have sqrt areas equal to the specified sizes, and aspect ratios as given.
    anchors = []
    for sz in sizes:
        for ratio in ratios:
            w = np.sqrt(sz * sz / ratio)
            h = ratio * w
            anchors.append([-w, -h, w, h])
    cell_anchors = np.asarray(anchors) * 0.5

    field_size = int(np.ceil(max_size / stride))
    shifts = (np.arange(0, field_size) * stride).astype("float32")
    shift_x, shift_y = np.meshgrid(shifts, shifts)
    shift_x = shift_x.flatten()
    shift_y = shift_y.flatten()
@@ -88,15 +88,12 @@ def get_all_anchors(*, stride, sizes, ratios, max_size):

    K = shifts.shape[0]
    A = cell_anchors.shape[0]
    field_of_anchors = cell_anchors.reshape((1, A, 4)) + shifts.reshape((1, K, 4)).transpose((1, 0, 2))
    field_of_anchors = field_of_anchors.reshape((field_size, field_size, A, 4))
    # FSxFSxAx4
    # A lot of rounding happens inside the anchor code anyway
    # assert np.all(field_of_anchors == field_of_anchors.astype('int32'))
    field_of_anchors = field_of_anchors.astype("float32")
    return field_of_anchors
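A quick property check of the new, un-quantized anchor generation (an illustrative snippet, not part of the commit): each cell anchor now has exactly the requested area and aspect ratio, which the rounded `generate_anchors` code being removed only approximated.

```python
import numpy as np

for sz in (32.0, 64.0, 128.0):
    for ratio in (0.5, 1.0, 2.0):
        w = np.sqrt(sz * sz / ratio)
        h = ratio * w
        # box = [-w, -h, w, h] * 0.5 -> width w, height h, centered on 0
        assert np.isclose(w * h, sz * sz)   # exact area
        assert np.isclose(h / w, ratio)     # exact aspect ratio
```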
@@ -121,18 +118,19 @@ class TrainingDataPreprocessor:
    Since the mapping may run in other processes, we write a new class and
    explicitly pass cfg to it, in the spirit of "explicitly pass resources to subprocess".
    """

    def __init__(self, cfg):
        self.cfg = cfg
        self.aug = imgaug.AugmentorList(
            [CustomResize(cfg.PREPROC.TRAIN_SHORT_EDGE_SIZE, cfg.PREPROC.MAX_SIZE), imgaug.Flip(horiz=True)]
        )

    def __call__(self, roidb):
        fname, boxes, klass, is_crowd = roidb["file_name"], roidb["boxes"], roidb["class"], roidb["is_crowd"]
        boxes = np.copy(boxes)
        im = cv2.imread(fname, cv2.IMREAD_COLOR)
        assert im is not None, fname
        im = im.astype("float32")
        height, width = im.shape[:2]
        # assume floatbox as input
        assert boxes.dtype == np.float32, "Loader has to return floating point boxes!"
@@ -148,30 +146,30 @@ class TrainingDataPreprocessor:
        boxes = point8_to_box(points)
        assert np.min(np_area(boxes)) > 0, "Some boxes have zero area!"

        ret = {"image": im}
        # Add rpn data to dataflow:
        try:
            if self.cfg.MODE_FPN:
                multilevel_anchor_inputs = self.get_multilevel_rpn_anchor_input(im, boxes, is_crowd)
                for i, (anchor_labels, anchor_boxes) in enumerate(multilevel_anchor_inputs):
                    ret["anchor_labels_lvl{}".format(i + 2)] = anchor_labels
                    ret["anchor_boxes_lvl{}".format(i + 2)] = anchor_boxes
            else:
                ret["anchor_labels"], ret["anchor_boxes"] = self.get_rpn_anchor_input(im, boxes, is_crowd)

            boxes = boxes[is_crowd == 0]  # skip crowd boxes in training target
            klass = klass[is_crowd == 0]
            ret["gt_boxes"] = boxes
            ret["gt_labels"] = klass
            if not len(boxes):
                raise MalformedData("No valid gt_boxes!")
        except MalformedData as e:
            log_once("Input {} is filtered for training: {}".format(fname, str(e)), "warn")
            return None

        if self.cfg.MODE_MASK:
            # augmentation will modify the polys in-place
            segmentation = copy.deepcopy(roidb["segmentation"])
            segmentation = [segmentation[k] for k in range(len(segmentation)) if not is_crowd[k]]
            assert len(segmentation) == len(boxes)
@@ -210,11 +208,14 @@ class TrainingDataPreprocessor:
        NA will be NUM_ANCHOR_SIZES x NUM_ANCHOR_RATIOS
        """
        boxes = boxes.copy()
        all_anchors = np.copy(
            get_all_anchors(
                stride=self.cfg.RPN.ANCHOR_STRIDE,
                sizes=self.cfg.RPN.ANCHOR_SIZES,
                ratios=self.cfg.RPN.ANCHOR_RATIOS,
                max_size=self.cfg.PREPROC.MAX_SIZE,
            )
        )
        # fHxfWxAx4 -> (-1, 4)
        featuremap_anchors_flatten = all_anchors.reshape((-1, 4))
@@ -222,15 +223,16 @@ class TrainingDataPreprocessor:
        inside_ind, inside_anchors = filter_boxes_inside_shape(featuremap_anchors_flatten, im.shape[:2])
        # obtain anchor labels and their corresponding gt boxes
        anchor_labels, anchor_gt_boxes = self.get_anchor_labels(
            inside_anchors, boxes[is_crowd == 0], boxes[is_crowd == 1]
        )

        # Fill them back to original size: fHxfWx1, fHxfWx4
        num_anchor = self.cfg.RPN.NUM_ANCHOR
        anchorH, anchorW = all_anchors.shape[:2]
        featuremap_labels = -np.ones((anchorH * anchorW * num_anchor,), dtype="int32")
        featuremap_labels[inside_ind] = anchor_labels
        featuremap_labels = featuremap_labels.reshape((anchorH, anchorW, num_anchor))
        featuremap_boxes = np.zeros((anchorH * anchorW * num_anchor, 4), dtype="float32")
        featuremap_boxes[inside_ind, :] = anchor_gt_boxes
        featuremap_boxes = featuremap_boxes.reshape((anchorH, anchorW, num_anchor, 4))
        return featuremap_labels, featuremap_boxes
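The "fill back" step above is a scatter pattern worth spelling out (illustrative sketch, toy sizes): labels are computed only for anchors fully inside the image, then written back into a full-size array whose remaining entries keep the -1 "ignore" label.

```python
import numpy as np

num_total = 6                                  # all anchors on the feature map
inside_ind = np.array([1, 4])                  # anchors inside the image
anchor_labels = np.array([1, 0], dtype="int32")

full = -np.ones((num_total,), dtype="int32")   # everything starts as "ignore"
full[inside_ind] = anchor_labels
assert full.tolist() == [-1, 1, -1, -1, 0, -1]
```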
@@ -254,32 +256,33 @@ class TrainingDataPreprocessor:
            strides=self.cfg.FPN.ANCHOR_STRIDES,
            sizes=self.cfg.RPN.ANCHOR_SIZES,
            ratios=self.cfg.RPN.ANCHOR_RATIOS,
            max_size=self.cfg.PREPROC.MAX_SIZE,
        )
        flatten_anchors_per_level = [k.reshape((-1, 4)) for k in anchors_per_level]
        all_anchors_flatten = np.concatenate(flatten_anchors_per_level, axis=0)

        inside_ind, inside_anchors = filter_boxes_inside_shape(all_anchors_flatten, im.shape[:2])
        anchor_labels, anchor_gt_boxes = self.get_anchor_labels(
            inside_anchors, boxes[is_crowd == 0], boxes[is_crowd == 1]
        )

        # map back to all_anchors, then split to each level
        num_all_anchors = all_anchors_flatten.shape[0]
        all_labels = -np.ones((num_all_anchors,), dtype="int32")
        all_labels[inside_ind] = anchor_labels
        all_boxes = np.zeros((num_all_anchors, 4), dtype="float32")
        all_boxes[inside_ind] = anchor_gt_boxes

        start = 0
        multilevel_inputs = []
        for level_anchor in anchors_per_level:
            assert level_anchor.shape[2] == len(self.cfg.RPN.ANCHOR_RATIOS)
            anchor_shape = level_anchor.shape[:3]  # fHxfWxNUM_ANCHOR_RATIOS
            num_anchor_this_level = np.prod(anchor_shape)
            end = start + num_anchor_this_level
            multilevel_inputs.append(
                (all_labels[start:end].reshape(anchor_shape), all_boxes[start:end, :].reshape(anchor_shape + (4,)))
            )
            start = end
        assert end == num_all_anchors, "{} != {}".format(end, num_all_anchors)
        return multilevel_inputs
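The split back to FPN levels relies only on each level's anchor count; a minimal sketch of the same slicing logic (names and counts are illustrative):

```python
import numpy as np

per_level_counts = [4, 2]                      # e.g. anchors on lvl2 and lvl3
all_labels = np.arange(sum(per_level_counts))  # stand-in for computed labels

start, per_level = 0, []
for n in per_level_counts:
    end = start + n
    per_level.append(all_labels[start:end])
    start = end
assert [x.tolist() for x in per_level] == [[0, 1, 2, 3], [4, 5]]
```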
@@ -300,10 +303,8 @@ class TrainingDataPreprocessor:

        def filter_box_label(labels, value, max_num):
            curr_inds = np.where(labels == value)[0]
            if len(curr_inds) > max_num:
                disable_inds = np.random.choice(curr_inds, size=(len(curr_inds) - max_num), replace=False)
                labels[disable_inds] = -1  # ignore them
                curr_inds = np.where(labels == value)[0]
            return curr_inds
@@ -317,7 +318,7 @@ class TrainingDataPreprocessor:

        anchors_with_max_iou_per_gt = np.where(box_ious == ious_max_per_gt)[0]

        # Setting NA labels: 1--fg 0--bg -1--ignore
        anchor_labels = -np.ones((NA,), dtype="int32")  # NA,

        # the order of setting neg/pos labels matters
        anchor_labels[anchors_with_max_iou_per_gt] = 1
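The comment that the order of setting labels matters comes down to later assignments overwriting earlier ones for anchors matched by more than one rule. A toy illustration (not the repo's actual thresholds or rule order):

```python
import numpy as np

labels = -np.ones((3,), dtype="int32")
cond_a = np.array([True, True, False])   # one labeling rule
cond_b = np.array([True, False, False])  # another rule, applied later

labels[cond_a] = 0
labels[cond_b] = 1
# anchor 0 satisfies both rules; the later write wins:
assert labels.tolist() == [1, 0, -1]
```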
@@ -345,10 +346,10 @@ class TrainingDataPreprocessor:
            # No valid bg in this image, skip.
            raise MalformedData("No valid background for RPN!")
        target_num_bg = self.cfg.RPN.BATCH_PER_IM - len(fg_inds)
        filter_box_label(anchor_labels, 0, target_num_bg)  # ignore return values

        # Set anchor boxes: the best gt_box for each fg anchor
        anchor_boxes = np.zeros((NA, 4), dtype="float32")
        fg_boxes = gt_boxes[ious_argmax_per_anchor[fg_inds], :]
        anchor_boxes[fg_inds, :] = fg_boxes
        # assert len(fg_inds) + np.sum(anchor_labels == 0) == self.cfg.RPN.BATCH_PER_IM
@@ -377,16 +378,19 @@ def get_train_dataflow():
    # Valid training images should have at least one fg box.
    # But this filter shall not be applied for testing.
    num = len(roidbs)
    roidbs = list(filter(lambda img: len(img["boxes"][img["is_crowd"] == 0]) > 0, roidbs))
    logger.info(
        "Filtered {} images which contain no non-crowd groundtruth boxes. Total #images for training: {}".format(
            num - len(roidbs), len(roidbs)
        )
    )

    ds = DataFromList(roidbs, shuffle=True)
    preprocess = TrainingDataPreprocessor(cfg)

    if cfg.DATA.NUM_WORKERS > 0:
        if cfg.TRAINER == "horovod":
            buffer_size = cfg.DATA.NUM_WORKERS * 10  # one dataflow for each process, therefore don't need large buffer
            ds = MultiThreadMapData(ds, cfg.DATA.NUM_WORKERS, preprocess, buffer_size=buffer_size)
            # MPI does not like fork()
@@ -412,21 +416,23 @@ def get_eval_dataflow(name, shard=0, num_shards=1):
    img_range = (shard * img_per_shard, (shard + 1) * img_per_shard if shard + 1 < num_shards else num_imgs)

    # no filter for training
    ds = DataFromListOfDict(roidbs[img_range[0]: img_range[1]], ["file_name", "image_id"])

    def f(fname):
        im = cv2.imread(fname, cv2.IMREAD_COLOR)
        assert im is not None, fname
        return im

    ds = MapDataComponent(ds, f, 0)
    # Evaluation itself may be multi-threaded, therefore don't add prefetch here.
    return ds
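The shard arithmetic above gives the last shard the remainder; a quick check (illustrative values):

```python
num_imgs, num_shards = 10, 3
img_per_shard = num_imgs // num_shards
ranges = [
    (shard * img_per_shard,
     (shard + 1) * img_per_shard if shard + 1 < num_shards else num_imgs)
    for shard in range(num_shards)
]
assert ranges == [(0, 3), (3, 6), (6, 10)]
```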
if __name__ == "__main__":
    import os

    from tensorpack.dataflow import PrintData

    cfg.DATA.BASEDIR = os.path.expanduser("~/data/coco")
    ds = get_train_dataflow()
    ds = PrintData(ds, 100)
    TestDataSpeed(ds, 50000).start()
@@ -13,6 +13,7 @@ from contextlib import ExitStack

import cv2
import pycocotools.mask as cocomask
import tqdm
from scipy import interpolate

from tensorpack.callbacks import Callback
from tensorpack.tfutils.common import get_tf_version_tuple

@@ -41,6 +42,23 @@ mask: None, or a binary image of the original image shape
"""
def _scale_box(box, scale):
    w_half = (box[2] - box[0]) * 0.5
    h_half = (box[3] - box[1]) * 0.5
    x_c = (box[2] + box[0]) * 0.5
    y_c = (box[3] + box[1]) * 0.5

    w_half *= scale
    h_half *= scale

    scaled_box = np.zeros_like(box)
    scaled_box[0] = x_c - w_half
    scaled_box[2] = x_c + w_half
    scaled_box[1] = y_c - h_half
    scaled_box[3] = y_c + h_half
    return scaled_box
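A worked example of `_scale_box` (illustrative): scaling keeps the box center fixed and multiplies width and height.

```python
import numpy as np

box = np.array([0.0, 0.0, 4.0, 2.0])   # 4x2 box centered at (2, 1)
out = _scale_box(box, 2.0)             # -> 8x4 box, same center
assert out.tolist() == [-2.0, -1.0, 6.0, 3.0]
```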
def _paste_mask(box, mask, shape):
    """
    Args:
@@ -50,23 +68,42 @@ def _paste_mask(box, mask, shape):
    Returns:
        A uint8 binary image of hxw.
    """
    assert mask.shape[0] == mask.shape[1], mask.shape

    if True:
        # This method is accurate but much slower.
        mask = np.pad(mask, [(1, 1), (1, 1)], mode='constant')
        box = _scale_box(box, float(mask.shape[0]) / (mask.shape[0] - 2))

        mask_pixels = np.arange(0.0, mask.shape[0]) + 0.5
        mask_continuous = interpolate.interp2d(mask_pixels, mask_pixels, mask, fill_value=0.0)
        h, w = shape
        ys = np.arange(0.0, h) + 0.5
        xs = np.arange(0.0, w) + 0.5
        ys = (ys - box[1]) / (box[3] - box[1]) * mask.shape[0]
        xs = (xs - box[0]) / (box[2] - box[0]) * mask.shape[1]
        res = mask_continuous(xs, ys)
        return (res >= 0.5).astype('uint8')
    else:
        # This method (inspired by Detectron) is less accurate but fast.

        # int() is floor
        # box fpcoor=0.0 -> intcoor=0.0
        x0, y0 = list(map(int, box[:2] + 0.5))
        # box fpcoor=h -> intcoor=h-1, inclusive
        x1, y1 = list(map(int, box[2:] - 0.5))  # inclusive
        x1 = max(x0, x1)  # require at least 1x1
        y1 = max(y0, y1)
        w = x1 + 1 - x0
        h = y1 + 1 - y0

        # rounding errors could happen here, because masks were not originally computed for this shape.
        # but it's hard to do better, because the network does not know the "original" scale
        mask = (cv2.resize(mask, (w, h)) > 0.5).astype('uint8')
        ret = np.zeros(shape, dtype='uint8')
        ret[y0:y1 + 1, x0:x1 + 1] = mask
        return ret
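A hedged usage sketch of the accurate path above (it assumes a scipy version that still provides `interpolate.interp2d`, and the helpers defined in this file; the mask size here is shrunk from the usual 28x28 for readability):

```python
import numpy as np

# Toy mask: top-left quadrant "on".
mask = np.zeros((4, 4), dtype="float32")
mask[:2, :2] = 1.0
box = np.array([0.0, 0.0, 8.0, 8.0])   # paste into the top-left 8x8 region
out = _paste_mask(box, mask, (8, 8))
print(out)   # expected: roughly the top-left 4x4 pixels are 1
```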
def predict_image(img, model_func):
@@ -82,7 +119,6 @@ def predict_image(img, model_func):
    Returns:
        [DetectionResult]
    """
    orig_shape = img.shape[:2]
    resizer = CustomResize(cfg.PREPROC.TEST_SHORT_EDGE_SIZE, cfg.PREPROC.MAX_SIZE)
    resized_img = resizer.augment(img)
# Some third-party helper functions # Some third-party helper functions
+ generate_anchors.py: copied from [py-faster-rcnn](https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/rpn/generate_anchors.py).
+ box_ops.py: modified from [TF object detection API](https://github.com/tensorflow/models/blob/master/research/object_detection/core/box_list_ops.py).
+ np_box_ops.py: copied from [TF object detection API](https://github.com/tensorflow/models/blob/master/research/object_detection/utils/np_box_ops.py).
# https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/rpn/generate_anchors.py
# --------------------------------------------------------
# Faster R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick and Sean Bell
# --------------------------------------------------------
import numpy as np
from six.moves import range
# Verify that we compute the same anchors as Shaoqing's matlab implementation:
#
# >> load output/rpn_cachedir/faster_rcnn_VOC2007_ZF_stage1_rpn/anchors.mat
# >> anchors
#
# anchors =
#
# -83 -39 100 56
# -175 -87 192 104
# -359 -183 376 200
# -55 -55 72 72
# -119 -119 136 136
# -247 -247 264 264
# -35 -79 52 96
# -79 -167 96 184
# -167 -343 184 360
# array([[ -83., -39., 100., 56.],
# [-175., -87., 192., 104.],
# [-359., -183., 376., 200.],
# [ -55., -55., 72., 72.],
# [-119., -119., 136., 136.],
# [-247., -247., 264., 264.],
# [ -35., -79., 52., 96.],
# [ -79., -167., 96., 184.],
# [-167., -343., 184., 360.]])
def generate_anchors(base_size=16, ratios=[0.5, 1, 2],
                     scales=2**np.arange(3, 6)):
    """
    Generate anchor (reference) windows by enumerating aspect ratios X
    scales wrt a reference (0, 0, 15, 15) window.
    """
    base_anchor = np.array([1, 1, base_size, base_size], dtype='float32') - 1
    ratio_anchors = _ratio_enum(base_anchor, ratios)
    anchors = np.vstack([_scale_enum(ratio_anchors[i, :], scales)
                         for i in range(ratio_anchors.shape[0])])
    return anchors


def _whctrs(anchor):
    """
    Return width, height, x center, and y center for an anchor (window).
    """
    w = anchor[2] - anchor[0] + 1
    h = anchor[3] - anchor[1] + 1
    x_ctr = anchor[0] + 0.5 * (w - 1)
    y_ctr = anchor[1] + 0.5 * (h - 1)
    return w, h, x_ctr, y_ctr


def _mkanchors(ws, hs, x_ctr, y_ctr):
    """
    Given a vector of widths (ws) and heights (hs) around a center
    (x_ctr, y_ctr), output a set of anchors (windows).
    """
    ws = ws[:, np.newaxis]
    hs = hs[:, np.newaxis]
    anchors = np.hstack((x_ctr - 0.5 * (ws - 1),
                         y_ctr - 0.5 * (hs - 1),
                         x_ctr + 0.5 * (ws - 1),
                         y_ctr + 0.5 * (hs - 1)))
    return anchors


def _ratio_enum(anchor, ratios):
    """
    Enumerate a set of anchors for each aspect ratio wrt an anchor.
    """
    w, h, x_ctr, y_ctr = _whctrs(anchor)
    size = w * h
    size_ratios = size / ratios
    ws = np.round(np.sqrt(size_ratios))
    hs = np.round(ws * ratios)
    anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
    return anchors


def _scale_enum(anchor, scales):
    """
    Enumerate a set of anchors for each scale wrt an anchor.
    """
    w, h, x_ctr, y_ctr = _whctrs(anchor)
    ws = w * scales
    hs = h * scales
    anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
    return anchors
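For reference, this deleted file is exactly where the quantization the commit removes came from: `_ratio_enum`/`_scale_enum` round widths and heights, so anchor areas and aspect ratios are only approximate. A small comparison (editor's illustration, using the first row of the anchor table in the comments above):

```python
import numpy as np

# First row above: size 128 (base_size 16, scale 8), ratio 0.5 -> [-83, -39, 100, 56].
w = 100 - (-83) + 1                  # quantized width  = 184
h = 56 - (-39) + 1                   # quantized height = 96
print(w * h, h / w)                  # 17664, ~0.522 (requested: 16384, 0.5)

# The replacement code in data.py generates the exact equivalent:
w_exact = np.sqrt(128 * 128 / 0.5)   # ~181.02
h_exact = 0.5 * w_exact              # ~90.51
print(w_exact * h_exact, h_exact / w_exact)   # 16384.0, 0.5
```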