Commit 2ff9a5f4 authored by Yuxin Wu

speed up dataloader: casting & clip

parent b2d106a3
@@ -37,12 +37,13 @@ def fbresnet_augmentor(isTrain):
     if isTrain:
         augmentors = [
             imgaug.GoogleNetRandomCropAndResize(interp=interpolation),
+            imgaug.ToFloat32(),  # avoid frequent casting in each color augmentation
             # It's OK to remove the following augs if your CPU is not fast enough.
             # Removing brightness/contrast/saturation does not have a significant effect on accuracy.
             # Removing lighting leads to a tiny drop in accuracy.
             imgaug.RandomOrderAug(
-                [imgaug.BrightnessScale((0.6, 1.4), clip=False),
-                 imgaug.Contrast((0.6, 1.4), rgb=False, clip=False),
+                [imgaug.BrightnessScale((0.6, 1.4)),
+                 imgaug.Contrast((0.6, 1.4), rgb=False),
                  imgaug.Saturation(0.4, rgb=False),
                  # rgb-bgr conversion for the constants copied from fb.resnet.torch
                  imgaug.Lighting(0.1,
@@ -54,6 +55,7 @@ def fbresnet_augmentor(isTrain):
                                       [-0.5836, -0.6948, 0.4203]],
                                      dtype='float32')[::-1, ::-1]
                                  )]),
+            imgaug.ToUint8(),
             imgaug.Flip(horiz=True),
         ]
     else:
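For context on the ToFloat32()/ToUint8() bracketing above (illustration only, not part of the commit): casting to float32 once up front and back to uint8 once at the end avoids a uint8-to-float round trip inside every color augmentation. A minimal numpy sketch of the two strategies, where brightness, cast_per_op and cast_once are made-up stand-ins rather than tensorpack code:

import numpy as np

def brightness(img, scale=1.2):
    # stand-in for a float-valued photometric op such as BrightnessScale
    return img * np.float32(scale)

def cast_per_op(img_u8, ops):
    out = img_u8
    for op in ops:
        out = op(out.astype("float32"))              # cast in before every op...
        out = np.clip(out, 0, 255).astype("uint8")   # ...and back out after it
    return out

def cast_once(img_u8, ops):
    out = img_u8.astype("float32")                   # one cast in (what ToFloat32 gives)
    for op in ops:
        out = op(out)                                # stay in float32 throughout
    return np.clip(out, 0, 255).astype("uint8")      # one cast out (what ToUint8 gives)

img = np.random.randint(0, 256, (224, 224, 3), dtype="uint8")
augs = [brightness, brightness, brightness]
assert cast_once(img, augs).dtype == np.uint8

Both variants do the same arithmetic, but the second copies the full image only twice instead of twice per augmentor, which is presumably the "casting" part of the speedup in the commit message.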
@@ -113,7 +113,7 @@ class ImageAugmentor(object):
             low, high = 0, low
         if size is None:
             size = []
-        return self.rng.uniform(low, high, size)
+        return self.rng.uniform(low, high, size).astype("float32")
 
     def __str__(self):
         try:
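A side note on the .astype("float32") added to _rand_range (my reading, not stated in the diff): np.random returns float64 by default, and multiplying a float32 image by a float64 array silently upcasts the whole result to float64, undoing the benefit of keeping images in float32. A quick check, with img32, v64 and v32 as illustrative names:

import numpy as np

rng = np.random.RandomState(0)
img32 = np.zeros((224, 224, 3), dtype="float32")

v64 = rng.uniform(0.6, 1.4, size=(3,))                    # numpy default: float64
v32 = rng.uniform(0.6, 1.4, size=(3,)).astype("float32")  # what _rand_range now returns

print((img32 * v64).dtype)   # float64 -- the whole image is upcast
print((img32 * v32).dtype)   # float32 -- stays in the cheaper dtype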
@@ -226,13 +226,14 @@ class Saturation(PhotometricAugmentor):
     <https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L218>`__.
     """
 
-    def __init__(self, alpha=0.4, rgb=True):
+    def __init__(self, alpha=0.4, rgb=True, clip=True):
         """
         Args:
             alpha(float): maximum saturation change.
             rgb (bool): whether input is RGB or BGR.
+            clip (bool): clip results to [0,255] even when data type is not uint8.
         """
-        super(Saturation, self).__init__()
+        super().__init__()
         rgb = bool(rgb)
         assert alpha < 1
         self._init(locals())
@@ -245,7 +246,7 @@ class Saturation(PhotometricAugmentor):
         m = cv2.COLOR_RGB2GRAY if self.rgb else cv2.COLOR_BGR2GRAY
         grey = cv2.cvtColor(img, m)
         ret = img * v + (grey * (1 - v))[:, :, np.newaxis]
-        if old_dtype == np.uint8:
+        if self.clip or old_dtype == np.uint8:
             ret = np.clip(ret, 0, 255)
         return ret.astype(old_dtype)
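Unrelated to the clip change itself, the ret = img * v + grey * (1 - v) line above is the usual saturation blend between an image and its greyscale version (the augmentor draws v near 1, within the alpha range described in the docstring). A tiny self-contained illustration, with a plain channel mean standing in for the cv2 grey conversion:

import numpy as np

img = (np.random.rand(4, 4, 3) * 255).astype("float32")
grey = img.mean(axis=2, keepdims=True)   # stand-in for cv2.cvtColor(img, cv2.COLOR_*2GRAY)

def saturate(img, grey, v):
    # v = 1 keeps the image unchanged, v = 0 is fully desaturated
    return img * np.float32(v) + grey * np.float32(1 - v)

np.testing.assert_allclose(saturate(img, grey, 1.0), img)
np.testing.assert_allclose(saturate(img, grey, 0.0), np.broadcast_to(grey, img.shape))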
@@ -258,16 +259,17 @@ class Lighting(PhotometricAugmentor):
     <https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L184>`__.
     """
 
-    def __init__(self, std, eigval, eigvec):
+    def __init__(self, std, eigval, eigvec, clip=True):
         """
         Args:
             std (float): maximum standard deviation
             eigval: a vector of (3,). The eigenvalues of 3 channels.
             eigvec: a 3x3 matrix. Each column is one eigen vector.
+            clip (bool): clip results to [0,255] even when data type is not uint8.
         """
         super(Lighting, self).__init__()
-        eigval = np.asarray(eigval)
-        eigvec = np.asarray(eigvec)
+        eigval = np.asarray(eigval, dtype="float32")
+        eigvec = np.asarray(eigvec, dtype="float32")
         assert eigval.shape == (3,)
         assert eigvec.shape == (3, 3)
         self._init(locals())
@@ -282,7 +284,7 @@ class Lighting(PhotometricAugmentor):
         v = v.reshape((3, 1))
         inc = np.dot(self.eigvec, v).reshape((3,))
         img = np.add(img, inc)
-        if old_dtype == np.uint8:
+        if self.clip or old_dtype == np.uint8:
             img = np.clip(img, 0, 255)
         return img.astype(old_dtype)
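For readers unfamiliar with this augmentor: the _augment code above is the AlexNet-style PCA lighting noise from fb.resnet.torch, and casting eigval/eigvec to float32 presumably keeps the per-image work in float32 once images arrive as float32 from the pipeline change above. A standalone sketch (not the library code) using the constants from the example file:

import numpy as np

rng = np.random.RandomState(0)
std = 0.1
eigval = (np.asarray([0.2175, 0.0188, 0.0045]) * 255.0).astype("float32")
eigvec = np.asarray([[-0.5675, 0.7192, 0.4009],
                     [-0.5808, -0.0045, -0.8140],
                     [-0.5836, -0.6948, 0.4203]], dtype="float32")

img = rng.randint(0, 256, (8, 8, 3)).astype("float32")

v = (rng.randn(3).astype("float32") * std) * eigval   # per-channel magnitudes ~ N(0, std), scaled by eigenvalues
inc = np.dot(eigvec, v.reshape(3, 1)).reshape(3)      # project back onto the color axes
out = np.clip(img + inc, 0, 255)                      # same clip as clip=True / uint8 input
print(out.dtype)                                      # float32: no float64 round trip

If eigvec were left as float64, the np.dot and the addition would upcast the whole image to float64, which is what the dtype="float32" cast in __init__ avoids.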