Commit ca7bc07f authored by Yuxin Wu's avatar Yuxin Wu

rgb option for Saturation and Hue (fix #332)

parent f17fab8e
......@@ -167,7 +167,7 @@ def get_data(train_or_test):
imgaug.RandomOrderAug(
[imgaug.Brightness(30, clip=False),
imgaug.Contrast((0.8, 1.2), clip=False),
imgaug.Saturation(0.4),
imgaug.Saturation(0.4, rgb=False),
# rgb-bgr conversion
imgaug.Lighting(0.1,
eigval=[0.2175, 0.0188, 0.0045][::-1],
......
......@@ -158,7 +158,7 @@ def get_data(train_or_test):
imgaug.RandomOrderAug(
[imgaug.Brightness(30, clip=False),
imgaug.Contrast((0.8, 1.2), clip=False),
imgaug.Saturation(0.4),
imgaug.Saturation(0.4, rgb=False),
imgaug.Lighting(0.1,
eigval=[0.2175, 0.0188, 0.0045][::-1],
eigvec=np.array(
......
......@@ -3,6 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from .base import ImageAugmentor
from ...utils import logger
import numpy as np
import cv2
......@@ -11,24 +12,33 @@ __all__ = ['Hue', 'Brightness', 'Contrast', 'MeanVarianceNormalize',
class Hue(ImageAugmentor):
""" Randomly change color hue of a BGR input.
""" Randomly change color hue.
"""
def __init__(self, range=(0, 180)):
def __init__(self, range=(0, 180), rgb=None):
"""
Args:
range(list or tuple): hue range
rgb (bool): whether input is RGB or BGR.
"""
super(Hue, self).__init__()
if rgb is None:
logger.warn("Hue() now assumes rgb=False, but will by default use rgb=True in the future!")
rgb = False
rgb = bool(rgb)
self._init(locals())
def _get_augment_params(self, _):
return self._rand_range(*self.range)
def _augment(self, img, hue):
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
m = cv2.COLOR_BGR2HSV if not self.rgb else cv2.COLOR_RGB2HSV
hsv = cv2.cvtColor(img, m)
# Note, OpenCV used 0-179 degree instead of 0-359 degree
hsv[..., 0] = (hsv[..., 0] + hue) % 180
img = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
m = cv2.COLOR_HSV2BGR if not self.rgb else cv2.COLOR_HSV2RGB
img = cv2.cvtColor(hsv, m)
return img
......@@ -174,17 +184,22 @@ class Clip(ImageAugmentor):
class Saturation(ImageAugmentor):
""" Randomly adjust saturation of BGR input.
""" Randomly adjust saturation.
Follows the implementation in `fb.resnet.torch
<https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L218>`__.
"""
def __init__(self, alpha=0.4):
def __init__(self, alpha=0.4, rgb=None):
"""
Args:
alpha(float): maximum saturation change.
rgb (bool): whether input is RGB or BGR.
"""
super(Saturation, self).__init__()
if rgb is None:
logger.warn("Saturation() now assumes rgb=False, but will by default use rgb=True in the future!")
rgb = False
rgb = bool(rgb)
assert alpha < 1
self._init(locals())
......@@ -193,7 +208,8 @@ class Saturation(ImageAugmentor):
def _augment(self, img, v):
old_dtype = img.dtype
grey = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
m = cv2.COLOR_RGB2GRAY if self.rgb else cv2.COLOR_BGR2GRAY
grey = cv2.cvtColor(img, m)
ret = img * v + (grey * (1 - v))[:, :, np.newaxis]
return ret.astype(old_dtype)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment