Commit ca7bc07f authored by Yuxin Wu's avatar Yuxin Wu

rgb option for Saturation and Hue (fix #332)

parent f17fab8e
...@@ -167,7 +167,7 @@ def get_data(train_or_test): ...@@ -167,7 +167,7 @@ def get_data(train_or_test):
imgaug.RandomOrderAug( imgaug.RandomOrderAug(
[imgaug.Brightness(30, clip=False), [imgaug.Brightness(30, clip=False),
imgaug.Contrast((0.8, 1.2), clip=False), imgaug.Contrast((0.8, 1.2), clip=False),
imgaug.Saturation(0.4), imgaug.Saturation(0.4, rgb=False),
# rgb-bgr conversion # rgb-bgr conversion
imgaug.Lighting(0.1, imgaug.Lighting(0.1,
eigval=[0.2175, 0.0188, 0.0045][::-1], eigval=[0.2175, 0.0188, 0.0045][::-1],
......
...@@ -158,7 +158,7 @@ def get_data(train_or_test): ...@@ -158,7 +158,7 @@ def get_data(train_or_test):
imgaug.RandomOrderAug( imgaug.RandomOrderAug(
[imgaug.Brightness(30, clip=False), [imgaug.Brightness(30, clip=False),
imgaug.Contrast((0.8, 1.2), clip=False), imgaug.Contrast((0.8, 1.2), clip=False),
imgaug.Saturation(0.4), imgaug.Saturation(0.4, rgb=False),
imgaug.Lighting(0.1, imgaug.Lighting(0.1,
eigval=[0.2175, 0.0188, 0.0045][::-1], eigval=[0.2175, 0.0188, 0.0045][::-1],
eigvec=np.array( eigvec=np.array(
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
from .base import ImageAugmentor from .base import ImageAugmentor
from ...utils import logger
import numpy as np import numpy as np
import cv2 import cv2
...@@ -11,24 +12,33 @@ __all__ = ['Hue', 'Brightness', 'Contrast', 'MeanVarianceNormalize', ...@@ -11,24 +12,33 @@ __all__ = ['Hue', 'Brightness', 'Contrast', 'MeanVarianceNormalize',
class Hue(ImageAugmentor): class Hue(ImageAugmentor):
""" Randomly change color hue of a BGR input. """ Randomly change color hue.
""" """
def __init__(self, range=(0, 180)): def __init__(self, range=(0, 180), rgb=None):
""" """
Args: Args:
range(list or tuple): hue range range(list or tuple): hue range
rgb (bool): whether input is RGB or BGR.
""" """
super(Hue, self).__init__()
if rgb is None:
logger.warn("Hue() now assumes rgb=False, but will by default use rgb=True in the future!")
rgb = False
rgb = bool(rgb)
self._init(locals()) self._init(locals())
def _get_augment_params(self, _): def _get_augment_params(self, _):
return self._rand_range(*self.range) return self._rand_range(*self.range)
def _augment(self, img, hue): def _augment(self, img, hue):
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) m = cv2.COLOR_BGR2HSV if not self.rgb else cv2.COLOR_RGB2HSV
hsv = cv2.cvtColor(img, m)
# Note, OpenCV used 0-179 degree instead of 0-359 degree # Note, OpenCV used 0-179 degree instead of 0-359 degree
hsv[..., 0] = (hsv[..., 0] + hue) % 180 hsv[..., 0] = (hsv[..., 0] + hue) % 180
img = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
m = cv2.COLOR_HSV2BGR if not self.rgb else cv2.COLOR_HSV2RGB
img = cv2.cvtColor(hsv, m)
return img return img
...@@ -174,17 +184,22 @@ class Clip(ImageAugmentor): ...@@ -174,17 +184,22 @@ class Clip(ImageAugmentor):
class Saturation(ImageAugmentor): class Saturation(ImageAugmentor):
""" Randomly adjust saturation of BGR input. """ Randomly adjust saturation.
Follows the implementation in `fb.resnet.torch Follows the implementation in `fb.resnet.torch
<https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L218>`__. <https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L218>`__.
""" """
def __init__(self, alpha=0.4): def __init__(self, alpha=0.4, rgb=None):
""" """
Args: Args:
alpha(float): maximum saturation change. alpha(float): maximum saturation change.
rgb (bool): whether input is RGB or BGR.
""" """
super(Saturation, self).__init__() super(Saturation, self).__init__()
if rgb is None:
logger.warn("Saturation() now assumes rgb=False, but will by default use rgb=True in the future!")
rgb = False
rgb = bool(rgb)
assert alpha < 1 assert alpha < 1
self._init(locals()) self._init(locals())
...@@ -193,7 +208,8 @@ class Saturation(ImageAugmentor): ...@@ -193,7 +208,8 @@ class Saturation(ImageAugmentor):
def _augment(self, img, v): def _augment(self, img, v):
old_dtype = img.dtype old_dtype = img.dtype
grey = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) m = cv2.COLOR_RGB2GRAY if self.rgb else cv2.COLOR_BGR2GRAY
grey = cv2.cvtColor(img, m)
ret = img * v + (grey * (1 - v))[:, :, np.newaxis] ret = img * v + (grey * (1 - v))[:, :, np.newaxis]
return ret.astype(old_dtype) return ret.astype(old_dtype)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment