Commit e6637136 authored by Yuxin Wu

[imgaug] improve grayscale and gaussianblur

parent aa7e18fc
...@@ -31,16 +31,29 @@ class ColorSpace(PhotometricAugmentor): ...@@ -31,16 +31,29 @@ class ColorSpace(PhotometricAugmentor):
class Grayscale(ColorSpace): class Grayscale(ColorSpace):
""" Convert image to grayscale. """ """ Convert RGB or BGR image to grayscale. """
def __init__(self, keepdims=True, rgb=False): def __init__(self, keepdims=True, rgb=False, keepshape=False):
""" """
Args: Args:
keepdims (bool): return image of shape [H, W, 1] instead of [H, W] keepdims (bool): return image of shape [H, W, 1] instead of [H, W]
rgb (bool): interpret input as RGB instead of the default BGR rgb (bool): interpret input as RGB instead of the default BGR
keepshape (bool): whether to duplicate the gray image into 3 channels
so the result has the same shape as input.
""" """
mode = cv2.COLOR_RGB2GRAY if rgb else cv2.COLOR_BGR2GRAY mode = cv2.COLOR_RGB2GRAY if rgb else cv2.COLOR_BGR2GRAY
if keepshape:
assert keepdims, "keepdims must be True when keepshape==True"
super(Grayscale, self).__init__(mode, keepdims) super(Grayscale, self).__init__(mode, keepdims)
self.keepshape = keepshape
self.rgb = rgb
def _augment(self, img, _):
ret = super()._augment(img, _)
if self.keepshape:
return np.concatenate([ret] * 3, axis=2)
else:
return ret
class ToUint8(PhotometricAugmentor): class ToUint8(PhotometricAugmentor):
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
import numpy as np import numpy as np
import cv2 import cv2
from ...utils.develop import log_deprecated
from .base import PhotometricAugmentor from .base import PhotometricAugmentor
__all__ = ['Hue', 'Brightness', 'BrightnessScale', 'Contrast', 'MeanVarianceNormalize', __all__ = ['Hue', 'Brightness', 'BrightnessScale', 'Contrast', 'MeanVarianceNormalize',
...@@ -164,22 +165,36 @@ class MeanVarianceNormalize(PhotometricAugmentor): ...@@ -164,22 +165,36 @@ class MeanVarianceNormalize(PhotometricAugmentor):
class GaussianBlur(PhotometricAugmentor): class GaussianBlur(PhotometricAugmentor):
""" Gaussian blur the image with random window size""" """ Gaussian blur the image with random window size"""
def __init__(self, max_size=3): def __init__(self, size_range=(0, 3), sigma_range=(0, 0), symmetric=True, max_size=None):
""" """
Args: Args:
max_size (int): max possible Gaussian window size would be 2 * max_size + 1 size_range (tuple[int]): Gaussian window size would be 2 * size +
1, where size is randomly sampled from this [low, high) range.
sigma_range (tuple[float]): min,max of the sigma value. 0 means
opencv's default.
symmetric (bool): whether to use the same size & sigma for x and y.
max_size (int): deprecated
""" """
super(GaussianBlur, self).__init__() super(GaussianBlur, self).__init__()
if not isinstance(size_range, (list, tuple)):
size_range = (0, size_range)
assert isinstance(sigma_range, (list, tuple)), sigma_range
if max_size is not None:
log_deprecated("GaussianBlur(max_size=)", "Use size_range= instead!", "2020-09-01")
size_range = (0, max_size)
self._init(locals()) self._init(locals())
def _get_augment_params(self, _): def _get_augment_params(self, _):
sx, sy = self.rng.randint(self.max_size, size=(2,)) size_xy = self.rng.randint(self.size_range[0], self.size_range[1], size=(2,)) * 2 + 1
sx = sx * 2 + 1 sigma_xy = self._rand_range(*self.sigma_range, size=(2,))
sy = sy * 2 + 1 if self.symmetric:
return sx, sy size_xy[1] = size_xy[0]
sigma_xy[1] = sigma_xy[0]
def _augment(self, img, s): return tuple(size_xy), tuple(sigma_xy)
return np.reshape(cv2.GaussianBlur(img, s, sigmaX=0, sigmaY=0,
def _augment(self, img, prm):
size, sigma = prm
return np.reshape(cv2.GaussianBlur(img, size, sigmaX=sigma[0], sigmaY=sigma[1],
borderType=cv2.BORDER_REPLICATE), img.shape) borderType=cv2.BORDER_REPLICATE), img.shape)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment