Commit e6637136 authored by Yuxin Wu's avatar Yuxin Wu

[imgaug] improve grayscale and gaussianblur

parent aa7e18fc
......@@ -31,16 +31,29 @@ class ColorSpace(PhotometricAugmentor):
class Grayscale(ColorSpace):
""" Convert image to grayscale. """
""" Convert RGB or BGR image to grayscale. """
def __init__(self, keepdims=True, rgb=False):
def __init__(self, keepdims=True, rgb=False, keepshape=False):
"""
Args:
keepdims (bool): return image of shape [H, W, 1] instead of [H, W]
rgb (bool): interpret input as RGB instead of the default BGR
keepshape (bool): whether to duplicate the gray image into 3 channels
so the result has the same shape as input.
"""
mode = cv2.COLOR_RGB2GRAY if rgb else cv2.COLOR_BGR2GRAY
if keepshape:
assert keepdims, "keepdims must be True when keepshape==True"
super(Grayscale, self).__init__(mode, keepdims)
self.keepshape = keepshape
self.rgb = rgb
def _augment(self, img, _):
ret = super()._augment(img, _)
if self.keepshape:
return np.concatenate([ret] * 3, axis=2)
else:
return ret
class ToUint8(PhotometricAugmentor):
......
......@@ -5,6 +5,7 @@
import numpy as np
import cv2
from ...utils.develop import log_deprecated
from .base import PhotometricAugmentor
__all__ = ['Hue', 'Brightness', 'BrightnessScale', 'Contrast', 'MeanVarianceNormalize',
......@@ -164,22 +165,36 @@ class MeanVarianceNormalize(PhotometricAugmentor):
class GaussianBlur(PhotometricAugmentor):
""" Gaussian blur the image with random window size"""
def __init__(self, max_size=3):
def __init__(self, size_range=(0, 3), sigma_range=(0, 0), symmetric=True, max_size=None):
"""
Args:
max_size (int): max possible Gaussian window size would be 2 * max_size + 1
size_range (tuple[int]): Gaussian window size would be 2 * size +
1, where size is randomly sampled from this [low, high) range.
sigma_range (tuple[float]): min,max of the sigma value. 0 means
opencv's default.
symmetric (bool): whether to use the same size & sigma for x and y.
max_size (int): deprecated
"""
super(GaussianBlur, self).__init__()
if not isinstance(size_range, (list, tuple)):
size_range = (0, size_range)
assert isinstance(sigma_range, (list, tuple)), sigma_range
if max_size is not None:
log_deprecated("GaussianBlur(max_size=)", "Use size_range= instead!", "2020-09-01")
size_range = (0, max_size)
self._init(locals())
def _get_augment_params(self, _):
sx, sy = self.rng.randint(self.max_size, size=(2,))
sx = sx * 2 + 1
sy = sy * 2 + 1
return sx, sy
def _augment(self, img, s):
return np.reshape(cv2.GaussianBlur(img, s, sigmaX=0, sigmaY=0,
size_xy = self.rng.randint(self.size_range[0], self.size_range[1], size=(2,)) * 2 + 1
sigma_xy = self._rand_range(*self.sigma_range, size=(2,))
if self.symmetric:
size_xy[1] = size_xy[0]
sigma_xy[1] = sigma_xy[0]
return tuple(size_xy), tuple(sigma_xy)
def _augment(self, img, prm):
size, sigma = prm
return np.reshape(cv2.GaussianBlur(img, size, sigmaX=sigma[0], sigmaY=sigma[1],
borderType=cv2.BORDER_REPLICATE), img.shape)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment