Commit 843d44e9 authored by Yuxin Wu's avatar Yuxin Wu

[MaskRCNN] train-time scale augmentation

parent 2efc98f7
...@@ -32,22 +32,27 @@ class CustomResize(transform.TransformAugmentorBase): ...@@ -32,22 +32,27 @@ class CustomResize(transform.TransformAugmentorBase):
while avoiding the longest edge to exceed max_size. while avoiding the longest edge to exceed max_size.
""" """
def __init__(self, size, max_size, interp=cv2.INTER_LINEAR): def __init__(self, short_edge_length, max_size, interp=cv2.INTER_LINEAR):
""" """
Args: Args:
size (int): the size to resize the shortest edge to. short_edge_length ([int, int]): a [min, max] interval from which to sample the
max_size (int): maximum allowed longest edge. shortest edge length.
max_size (int): maximum allowed longest edge length.
""" """
super(CustomResize, self).__init__() super(CustomResize, self).__init__()
if isinstance(short_edge_length, int):
short_edge_length = (short_edge_length, short_edge_length)
self._init(locals()) self._init(locals())
def _get_augment_params(self, img): def _get_augment_params(self, img):
h, w = img.shape[:2] h, w = img.shape[:2]
scale = self.size * 1.0 / min(h, w) size = self.rng.randint(
self.short_edge_length[0], self.short_edge_length[1] + 1)
scale = size * 1.0 / min(h, w)
if h < w: if h < w:
newh, neww = self.size, scale * w newh, neww = size, scale * w
else: else:
newh, neww = scale * h, self.size newh, neww = scale * h, size
if max(newh, neww) > self.max_size: if max(newh, neww) > self.max_size:
scale = self.max_size * 1.0 / max(newh, neww) scale = self.max_size * 1.0 / max(newh, neww)
newh = newh * scale newh = newh * scale
......
...@@ -107,7 +107,8 @@ _C.TRAIN.NUM_EVALS = 20 # number of evaluations to run during training ...@@ -107,7 +107,8 @@ _C.TRAIN.NUM_EVALS = 20 # number of evaluations to run during training
# preprocessing -------------------- # preprocessing --------------------
# Alternative old (worse & faster) setting: 600, 1024 # Alternative old (worse & faster) setting: 600, 1024
_C.PREPROC.SHORT_EDGE_SIZE = 800 _C.PREPROC.TRAIN_SHORT_EDGE_SIZE = [800, 800] # [min, max] to sample from
_C.PREPROC.TEST_SHORT_EDGE_SIZE = 800
_C.PREPROC.MAX_SIZE = 1333 _C.PREPROC.MAX_SIZE = 1333
# mean and std in RGB order. # mean and std in RGB order.
# Un-scaled version: [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] # Un-scaled version: [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
......
...@@ -311,7 +311,7 @@ def get_train_dataflow(): ...@@ -311,7 +311,7 @@ def get_train_dataflow():
ds = DataFromList(roidbs, shuffle=True) ds = DataFromList(roidbs, shuffle=True)
aug = imgaug.AugmentorList( aug = imgaug.AugmentorList(
[CustomResize(cfg.PREPROC.SHORT_EDGE_SIZE, cfg.PREPROC.MAX_SIZE), [CustomResize(cfg.PREPROC.TRAIN_SHORT_EDGE_SIZE, cfg.PREPROC.MAX_SIZE),
imgaug.Flip(horiz=True)]) imgaug.Flip(horiz=True)])
def preprocess(roidb): def preprocess(roidb):
......
...@@ -70,7 +70,7 @@ def detect_one_image(img, model_func): ...@@ -70,7 +70,7 @@ def detect_one_image(img, model_func):
""" """
orig_shape = img.shape[:2] orig_shape = img.shape[:2]
resizer = CustomResize(cfg.PREPROC.SHORT_EDGE_SIZE, cfg.PREPROC.MAX_SIZE) resizer = CustomResize(cfg.PREPROC.TEST_SHORT_EDGE_SIZE, cfg.PREPROC.MAX_SIZE)
resized_img = resizer.augment(img) resized_img = resizer.augment(img)
scale = np.sqrt(resized_img.shape[0] * 1.0 / img.shape[0] * resized_img.shape[1] / img.shape[1]) scale = np.sqrt(resized_img.shape[0] * 1.0 / img.shape[0] * resized_img.shape[1] / img.shape[1])
boxes, probs, labels, *masks = model_func(resized_img) boxes, probs, labels, *masks = model_func(resized_img)
......
...@@ -162,7 +162,7 @@ class ResNetC4Model(DetectionModel): ...@@ -162,7 +162,7 @@ class ResNetC4Model(DetectionModel):
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_outputs('fastrcnn', feature_gap, cfg.DATA.NUM_CLASS) fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_outputs('fastrcnn', feature_gap, cfg.DATA.NUM_CLASS)
fastrcnn_head = FastRCNNHead(rcnn_boxes, fastrcnn_box_logits, fastrcnn_label_logits, fastrcnn_head = FastRCNNHead(rcnn_boxes, fastrcnn_box_logits, fastrcnn_label_logits,
tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS), tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS, dtype=tf.float32),
rcnn_labels, matched_gt_boxes) rcnn_labels, matched_gt_boxes)
if is_training: if is_training:
...@@ -293,7 +293,7 @@ class ResNetFPNModel(DetectionModel): ...@@ -293,7 +293,7 @@ class ResNetFPNModel(DetectionModel):
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_head_func( fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_head_func(
'fastrcnn', roi_feature_fastrcnn, cfg.DATA.NUM_CLASS) 'fastrcnn', roi_feature_fastrcnn, cfg.DATA.NUM_CLASS)
fastrcnn_head = FastRCNNHead(rcnn_boxes, fastrcnn_box_logits, fastrcnn_label_logits, fastrcnn_head = FastRCNNHead(rcnn_boxes, fastrcnn_box_logits, fastrcnn_label_logits,
tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS), tf.constant(cfg.FRCNN.BBOX_REG_WEIGHTS, dtype=tf.float32),
rcnn_labels, matched_gt_boxes) rcnn_labels, matched_gt_boxes)
if is_training: if is_training:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment