Commit be51dd88 authored by Yuxin Wu's avatar Yuxin Wu

Add RandomCutout

parent 1870496f
...@@ -7,10 +7,11 @@ import cv2 ...@@ -7,10 +7,11 @@ import cv2
from ...utils.argtools import shape2d from ...utils.argtools import shape2d
from ...utils.develop import log_deprecated from ...utils.develop import log_deprecated
from .base import ImageAugmentor, ImagePlaceholder from .base import ImageAugmentor, ImagePlaceholder
from .transform import CropTransform, TransformList, ResizeTransform from .transform import CropTransform, TransformList, ResizeTransform, PhotometricTransform
from .misc import ResizeShortestEdge from .misc import ResizeShortestEdge
__all__ = ['RandomCrop', 'CenterCrop', 'RandomCropRandomShape', 'GoogleNetRandomCropAndResize'] __all__ = ['RandomCrop', 'CenterCrop', 'RandomCropRandomShape',
'GoogleNetRandomCropAndResize', 'RandomCutout']
class RandomCrop(ImageAugmentor): class RandomCrop(ImageAugmentor):
...@@ -132,3 +133,44 @@ class GoogleNetRandomCropAndResize(ImageAugmentor): ...@@ -132,3 +133,44 @@ class GoogleNetRandomCropAndResize(ImageAugmentor):
out_shape = (resize.new_h, resize.new_w) out_shape = (resize.new_h, resize.new_w)
crop = CenterCrop(self.target_shape).get_transform(ImagePlaceholder(shape=out_shape)) crop = CenterCrop(self.target_shape).get_transform(ImagePlaceholder(shape=out_shape))
return TransformList([resize, crop]) return TransformList([resize, crop])
class RandomCutout(ImageAugmentor):
"""
The cutout augmentation, as described in https://arxiv.org/abs/1708.04552
"""
def __init__(self, h_range, w_range, fill=0.):
"""
Args:
h_range (int or tuple): the height of rectangle to cut.
If a tuple, will randomly sample from this range [low, high)
w_range (int or tuple): similar to above
fill (float): the fill value
"""
super(RandomCutout, self).__init__()
self._init(locals())
def _get_cutout_shape(self):
if isinstance(self.h_range, int):
h = self.h_range
else:
h = self.rng.randint(self.h_range)
if isinstance(self.w_range, int):
w = self.w_range
else:
w = self.rng.randint(self.w_range)
return h, w
@staticmethod
def _cutout(img, y0, x0, h, w, fill):
img[y0:y0 + h, x0:x0 + w] = fill
return img
def get_transform(self, img):
h, w = self._get_cutout_shape()
x0 = self.rng.randint(0, img.shape[1] + 1 - w)
y0 = self.rng.randint(0, img.shape[0] + 1 - h)
return PhotometricTransform(
lambda img: RandomCutout._cutout(img, y0, x0, h, w, self.fill),
"cutout")
...@@ -15,7 +15,8 @@ Legacy alias. Please don't use. ...@@ -15,7 +15,8 @@ Legacy alias. Please don't use.
# This legacy augmentor requires us to import base from here, causing circular dependency. # This legacy augmentor requires us to import base from here, causing circular dependency.
# Should remove this in the future. # Should remove this in the future.
__all__ = ["Transform", "ResizeTransform", "CropTransform", "FlipTransform", "TransformList", "TransformFactory"] __all__ = ["Transform", "ResizeTransform", "CropTransform", "FlipTransform",
"TransformList", "TransformFactory"]
# class WrappedImgFunc(object): # class WrappedImgFunc(object):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment