Commit 2ff9a5f4 authored by Yuxin Wu

speed up dataloader: casting & clip

parent b2d106a3
@@ -37,12 +37,13 @@ def fbresnet_augmentor(isTrain):
     if isTrain:
         augmentors = [
             imgaug.GoogleNetRandomCropAndResize(interp=interpolation),
+            imgaug.ToFloat32(),  # avoid frequent casting in each color augmentation
             # It's OK to remove the following augs if your CPU is not fast enough.
             # Removing brightness/contrast/saturation does not have a significant effect on accuracy.
             # Removing lighting leads to a tiny drop in accuracy.
             imgaug.RandomOrderAug(
-                [imgaug.BrightnessScale((0.6, 1.4), clip=False),
-                 imgaug.Contrast((0.6, 1.4), rgb=False, clip=False),
+                [imgaug.BrightnessScale((0.6, 1.4)),
+                 imgaug.Contrast((0.6, 1.4), rgb=False),
                  imgaug.Saturation(0.4, rgb=False),
                  # rgb-bgr conversion for the constants copied from fb.resnet.torch
                  imgaug.Lighting(0.1,
@@ -54,6 +55,7 @@ def fbresnet_augmentor(isTrain):
                                       [-0.5836, -0.6948, 0.4203]],
                                      dtype='float32')[::-1, ::-1]
                                  )]),
+            imgaug.ToUint8(),
             imgaug.Flip(horiz=True),
         ]
     else:
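
For context, the net effect of the new ToFloat32()/ToUint8() pair is that the uint8-to-float32 conversion (and the final clip/cast back) happens once for the whole color-augmentation block instead of inside every individual augmentor. A rough sketch of that pattern in plain NumPy (illustrative helper only, not tensorpack API):

import numpy as np

def color_aug_block(img_uint8, color_augs):
    """Apply a list of float-based color augmentations with a single cast and clip."""
    img = img_uint8.astype("float32")              # cast once, instead of per-augmentor
    for aug in color_augs:
        img = aug(img)                             # each aug works purely in float32
    return np.clip(img, 0, 255).astype("uint8")    # clip and cast back once at the end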
@@ -113,7 +113,7 @@ class ImageAugmentor(object):
             low, high = 0, low
         if size is None:
             size = []
-        return self.rng.uniform(low, high, size)
+        return self.rng.uniform(low, high, size).astype("float32")

     def __str__(self):
         try:
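
Returning float32 random values from `_rand_range` (together with the float32 eigval/eigvec below) matters because mixing a float32 image with float64 arrays silently promotes the whole image to float64, doubling the memory traffic of each augmentor. The promotion rule itself is plain NumPy behaviour and easy to verify:

import numpy as np

img = np.zeros((224, 224, 3), dtype="float32")
inc = np.random.randn(3)                        # float64, like an un-cast Lighting offset

print((img + inc).dtype)                        # float64 -> the whole image was upcast
print((img + inc.astype("float32")).dtype)      # float32 -> stays in the cheaper dtype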
@@ -226,13 +226,14 @@ class Saturation(PhotometricAugmentor):
     <https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L218>`__.
     """

-    def __init__(self, alpha=0.4, rgb=True):
+    def __init__(self, alpha=0.4, rgb=True, clip=True):
         """
         Args:
             alpha(float): maximum saturation change.
             rgb (bool): whether input is RGB or BGR.
+            clip (bool): clip results to [0,255] even when data type is not uint8.
         """
-        super(Saturation, self).__init__()
+        super().__init__()
         rgb = bool(rgb)
         assert alpha < 1
         self._init(locals())
@@ -245,7 +246,7 @@ class Saturation(PhotometricAugmentor):
         m = cv2.COLOR_RGB2GRAY if self.rgb else cv2.COLOR_BGR2GRAY
         grey = cv2.cvtColor(img, m)
         ret = img * v + (grey * (1 - v))[:, :, np.newaxis]
-        if old_dtype == np.uint8:
+        if self.clip or old_dtype == np.uint8:
             ret = np.clip(ret, 0, 255)
         return ret.astype(old_dtype)
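
With the new `clip=True` default, Saturation (and Lighting below) clip their result to [0,255] even when the input is already float32, so the float pipeline produced by ToFloat32() keeps the same value range as the old uint8 path; passing clip=False restores the previous behaviour for float input. A condensed sketch of just that branch, assuming a float32 image (helper name is illustrative, not tensorpack API):

import numpy as np

def finalize(ret, old_dtype, clip=True):
    # mirrors the changed branch: always clip for uint8, and for float too unless clip=False
    if clip or old_dtype == np.uint8:
        ret = np.clip(ret, 0, 255)
    return ret.astype(old_dtype)

img = np.array([-20.0, 300.0], dtype="float32")
print(finalize(img, img.dtype))               # [  0. 255.] -- clipped despite float input
print(finalize(img, img.dtype, clip=False))   # [-20. 300.] -- old behaviour for float input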
@@ -258,16 +259,17 @@ class Lighting(PhotometricAugmentor):
     <https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L184>`__.
     """

-    def __init__(self, std, eigval, eigvec):
+    def __init__(self, std, eigval, eigvec, clip=True):
         """
         Args:
             std (float): maximum standard deviation
             eigval: a vector of (3,). The eigenvalues of 3 channels.
             eigvec: a 3x3 matrix. Each column is one eigen vector.
+            clip (bool): clip results to [0,255] even when data type is not uint8.
         """
         super(Lighting, self).__init__()
-        eigval = np.asarray(eigval)
-        eigvec = np.asarray(eigvec)
+        eigval = np.asarray(eigval, dtype="float32")
+        eigvec = np.asarray(eigvec, dtype="float32")
         assert eigval.shape == (3,)
         assert eigvec.shape == (3, 3)
         self._init(locals())
@@ -282,7 +284,7 @@ class Lighting(PhotometricAugmentor):
         v = v.reshape((3, 1))
         inc = np.dot(self.eigvec, v).reshape((3,))
         img = np.add(img, inc)
-        if old_dtype == np.uint8:
+        if self.clip or old_dtype == np.uint8:
             img = np.clip(img, 0, 255)
         return img.astype(old_dtype)