Commit be51dd88 authored by Yuxin Wu's avatar Yuxin Wu

Add RandomCutout

parent 1870496f
......@@ -7,10 +7,11 @@ import cv2
from ...utils.argtools import shape2d
from ...utils.develop import log_deprecated
from .base import ImageAugmentor, ImagePlaceholder
from .transform import CropTransform, TransformList, ResizeTransform
from .transform import CropTransform, TransformList, ResizeTransform, PhotometricTransform
from .misc import ResizeShortestEdge
__all__ = ['RandomCrop', 'CenterCrop', 'RandomCropRandomShape', 'GoogleNetRandomCropAndResize']
__all__ = ['RandomCrop', 'CenterCrop', 'RandomCropRandomShape',
'GoogleNetRandomCropAndResize', 'RandomCutout']
class RandomCrop(ImageAugmentor):
......@@ -132,3 +133,44 @@ class GoogleNetRandomCropAndResize(ImageAugmentor):
out_shape = (resize.new_h, resize.new_w)
crop = CenterCrop(self.target_shape).get_transform(ImagePlaceholder(shape=out_shape))
return TransformList([resize, crop])
class RandomCutout(ImageAugmentor):
"""
The cutout augmentation, as described in https://arxiv.org/abs/1708.04552
"""
def __init__(self, h_range, w_range, fill=0.):
"""
Args:
h_range (int or tuple): the height of rectangle to cut.
If a tuple, will randomly sample from this range [low, high)
w_range (int or tuple): similar to above
fill (float): the fill value
"""
super(RandomCutout, self).__init__()
self._init(locals())
def _get_cutout_shape(self):
if isinstance(self.h_range, int):
h = self.h_range
else:
h = self.rng.randint(self.h_range)
if isinstance(self.w_range, int):
w = self.w_range
else:
w = self.rng.randint(self.w_range)
return h, w
@staticmethod
def _cutout(img, y0, x0, h, w, fill):
img[y0:y0 + h, x0:x0 + w] = fill
return img
def get_transform(self, img):
h, w = self._get_cutout_shape()
x0 = self.rng.randint(0, img.shape[1] + 1 - w)
y0 = self.rng.randint(0, img.shape[0] + 1 - h)
return PhotometricTransform(
lambda img: RandomCutout._cutout(img, y0, x0, h, w, self.fill),
"cutout")
......@@ -15,7 +15,8 @@ Legacy alias. Please don't use.
# This legacy augmentor requires us to import base from here, causing circular dependency.
# Should remove this in the future.
__all__ = ["Transform", "ResizeTransform", "CropTransform", "FlipTransform", "TransformList", "TransformFactory"]
__all__ = ["Transform", "ResizeTransform", "CropTransform", "FlipTransform",
"TransformList", "TransformFactory"]
# class WrappedImgFunc(object):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment