Commit 843d44e9 authored by Yuxin Wu

[MaskRCNN] train-time scale augmentation

parent 2efc98f7
......@@ -32,22 +32,27 @@ class CustomResize(transform.TransformAugmentorBase):
while avoiding the longest edge to exceed max_size.
"""
def __init__(self, size, max_size, interp=cv2.INTER_LINEAR):
def __init__(self, short_edge_length, max_size, interp=cv2.INTER_LINEAR):
"""
Args:
size (int): the size to resize the shortest edge to.
max_size (int): maximum allowed longest edge.
short_edge_length ([int, int]): a [min, max] interval from which to sample the
shortest edge length.
max_size (int): maximum allowed longest edge length.
"""
super(CustomResize, self).__init__()
if isinstance(short_edge_length, int):
short_edge_length = (short_edge_length, short_edge_length)
self._init(locals())
def _get_augment_params(self, img):
h, w = img.shape[:2]
scale = self.size * 1.0 / min(h, w)
size = self.rng.randint(
self.short_edge_length[0], self.short_edge_length[1] + 1)
scale = size * 1.0 / min(h, w)
if h < w:
newh, neww = self.size, scale * w
newh, neww = size, scale * w
else:
newh, neww = scale * h, self.size
newh, neww = scale * h, size
if max(newh, neww) > self.max_size:
scale = self.max_size * 1.0 / max(newh, neww)
newh = newh * scale
......
......@@ -107,7 +107,8 @@ _C.TRAIN.NUM_EVALS = 20 # number of evaluations to run during training
# preprocessing --------------------
# Alternative old (worse & faster) setting: 600, 1024
_C.PREPROC.SHORT_EDGE_SIZE = 800
_C.PREPROC.TRAIN_SHORT_EDGE_SIZE = [800, 800] # [min, max] to sample from
_C.PREPROC.TEST_SHORT_EDGE_SIZE = 800
_C.PREPROC.MAX_SIZE = 1333
# mean and std in RGB order.
# Un-scaled version: [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
......
......@@ -311,7 +311,7 @@ def get_train_dataflow():
ds = DataFromList(roidbs, shuffle=True)
aug = imgaug.AugmentorList(
[CustomResize(cfg.PREPROC.SHORT_EDGE_SIZE, cfg.PREPROC.MAX_SIZE),
[CustomResize(cfg.PREPROC.TRAIN_SHORT_EDGE_SIZE, cfg.PREPROC.MAX_SIZE),
imgaug.Flip(horiz=True)])
def preprocess(roidb):
......
......@@ -70,7 +70,7 @@ def detect_one_image(img, model_func):
"""
orig_shape = img.shape[:2]
resizer = CustomResize(cfg.PREPROC.SHORT_EDGE_SIZE, cfg.PREPROC.MAX_SIZE)
resizer = CustomResize(cfg.PREPROC.TEST_SHORT_EDGE_SIZE, cfg.PREPROC.MAX_SIZE)
resized_img = resizer.augment(img)
scale = np.sqrt(resized_img.shape[0] * 1.0 / img.shape[0] * resized_img.shape[1] / img.shape[1])
boxes, probs, labels, *masks = model_func(resized_img)
......
......@@ -162,7 +162,7 @@ class ResNetC4Model(DetectionModel):
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_outputs('fastrcnn', feature_gap, cfg.DATA.NUM_CLASS)
fastrcnn_head = FastRCNNHead(rcnn_boxes, fastrcnn_box_logits, fastrcnn_label_logits,
tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS),
tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS, dtype=tf.float32),
rcnn_labels, matched_gt_boxes)
if is_training:
......@@ -293,7 +293,7 @@ class ResNetFPNModel(DetectionModel):
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_head_func(
'fastrcnn', roi_feature_fastrcnn, cfg.DATA.NUM_CLASS)
fastrcnn_head = FastRCNNHead(rcnn_boxes, fastrcnn_box_logits, fastrcnn_label_logits,
tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS),
tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS, dtype=tf.float32),
rcnn_labels, matched_gt_boxes)
if is_training:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment