Commit 1e91d921 authored by Sujay Narumanchi's avatar Sujay Narumanchi Committed by Yuxin Wu

Minor bugfix imgaug (#24)

* Fix missing super class init call in image augmentors

* Add salt and pepper noise augmentor

* Improve readability for salt pepper noise, make augmentation in place
parent de1a5acd
...@@ -10,6 +10,7 @@ from .crop import * ...@@ -10,6 +10,7 @@ from .crop import *
from .imgproc import * from .imgproc import *
from .noname import * from .noname import *
from .deform import * from .deform import *
from .noise import SaltPepperNoise
anchors = [(0.2, 0.2), (0.7, 0.2), (0.8, 0.8), (0.5, 0.5), (0.2, 0.5)] anchors = [(0.2, 0.2), (0.7, 0.2), (0.8, 0.8), (0.5, 0.5), (0.2, 0.5)]
...@@ -17,7 +18,8 @@ augmentors = AugmentorList([ ...@@ -17,7 +18,8 @@ augmentors = AugmentorList([
Contrast((0.8,1.2)), Contrast((0.8,1.2)),
Flip(horiz=True), Flip(horiz=True),
GaussianDeform(anchors, (360,480), 0.2, randrange=20), GaussianDeform(anchors, (360,480), 0.2, randrange=20),
#RandomCropRandomShape(0.3) #RandomCropRandomShape(0.3),
SaltPepperNoise()
]) ])
img = cv2.imread(sys.argv[1]) img = cv2.imread(sys.argv[1])
......
...@@ -17,6 +17,7 @@ class RandomCrop(ImageAugmentor): ...@@ -17,6 +17,7 @@ class RandomCrop(ImageAugmentor):
""" """
:param crop_shape: a shape like (h, w) :param crop_shape: a shape like (h, w)
""" """
super(RandomCrop, self).__init__()
self._init(locals()) self._init(locals())
def _get_augment_params(self, img): def _get_augment_params(self, img):
...@@ -120,6 +121,7 @@ class RandomCropRandomShape(ImageAugmentor): ...@@ -120,6 +121,7 @@ class RandomCropRandomShape(ImageAugmentor):
:param perturb_ratio: perturb distance will be in [0, perturb_ratio * sqrt(w * h)] :param perturb_ratio: perturb distance will be in [0, perturb_ratio * sqrt(w * h)]
:param max_aspect_ratio_diff: keep aspect ratio within the range :param max_aspect_ratio_diff: keep aspect ratio within the range
""" """
super(RandomCropRandomShape, self).__init__()
self._init(locals()) self._init(locals())
def _get_augment_params(self, img): def _get_augment_params(self, img):
......
...@@ -19,6 +19,7 @@ class Rotation(ImageAugmentor): ...@@ -19,6 +19,7 @@ class Rotation(ImageAugmentor):
:param max_deg: max abs value of the rotation degree :param max_deg: max abs value of the rotation degree
:param center_range: the location of the rotation center :param center_range: the location of the rotation center
""" """
super(Rotation, self).__init__()
self._init(locals()) self._init(locals())
def _get_augment_params(self, img): def _get_augment_params(self, img):
...@@ -37,6 +38,7 @@ class RotationAndCropValid(ImageAugmentor): ...@@ -37,6 +38,7 @@ class RotationAndCropValid(ImageAugmentor):
This will produce images of different shapes. This will produce images of different shapes.
""" """
def __init__(self, max_deg, interp=cv2.INTER_CUBIC): def __init__(self, max_deg, interp=cv2.INTER_CUBIC):
super(RotationAndCropValid, self).__init__()
self._init(locals()) self._init(locals())
def _get_augment_params(self, img): def _get_augment_params(self, img):
......
...@@ -17,6 +17,7 @@ class Brightness(ImageAugmentor): ...@@ -17,6 +17,7 @@ class Brightness(ImageAugmentor):
""" """
Randomly add a value within [-delta,delta], and clip in [0,255] if clip is True. Randomly add a value within [-delta,delta], and clip in [0,255] if clip is True.
""" """
super(Brightness, self).__init__()
assert delta > 0 assert delta > 0
self._init(locals()) self._init(locals())
...@@ -40,6 +41,7 @@ class Contrast(ImageAugmentor): ...@@ -40,6 +41,7 @@ class Contrast(ImageAugmentor):
:param factor_range: an interval to random sample the `contrast_factor`. :param factor_range: an interval to random sample the `contrast_factor`.
:param clip: boolean. :param clip: boolean.
""" """
super(Contrast, self).__init__()
self._init(locals()) self._init(locals())
def _get_augment_params(self, img): def _get_augment_params(self, img):
...@@ -79,6 +81,7 @@ class MeanVarianceNormalize(ImageAugmentor): ...@@ -79,6 +81,7 @@ class MeanVarianceNormalize(ImageAugmentor):
class GaussianBlur(ImageAugmentor): class GaussianBlur(ImageAugmentor):
def __init__(self, max_size=3): def __init__(self, max_size=3):
""":params max_size: (maximum kernel size-1)/2""" """:params max_size: (maximum kernel size-1)/2"""
super(GaussianBlur, self).__init__()
self._init(locals()) self._init(locals())
def _get_augment_params(self, img): def _get_augment_params(self, img):
...@@ -94,9 +97,12 @@ class GaussianBlur(ImageAugmentor): ...@@ -94,9 +97,12 @@ class GaussianBlur(ImageAugmentor):
class Gamma(ImageAugmentor): class Gamma(ImageAugmentor):
def __init__(self, range=(-0.5, 0.5)): def __init__(self, range=(-0.5, 0.5)):
super(Gamma, self).__init__()
self._init(locals()) self._init(locals())
def _get_augment_params(self, _): def _get_augment_params(self, _):
return self._rand_range(*self.range) return self._rand_range(*self.range)
def _augment(self, img, gamma): def _augment(self, img, gamma):
lut = ((np.arange(256, dtype='float32') / 255) ** (1. / (1. + gamma)) * 255).astype('uint8') lut = ((np.arange(256, dtype='float32') / 255) ** (1. / (1. + gamma)) * 255).astype('uint8')
img = np.clip(img, 0, 255).astype('uint8') img = np.clip(img, 0, 255).astype('uint8')
......
...@@ -15,6 +15,7 @@ class Identity(ImageAugmentor): ...@@ -15,6 +15,7 @@ class Identity(ImageAugmentor):
class RandomApplyAug(ImageAugmentor): class RandomApplyAug(ImageAugmentor):
""" Randomly apply the augmentor with a prob. Otherwise do nothing""" """ Randomly apply the augmentor with a prob. Otherwise do nothing"""
def __init__(self, aug, prob): def __init__(self, aug, prob):
super(RandomApplyAug, self).__init__()
self._init(locals()) self._init(locals())
def _get_augment_params(self, img): def _get_augment_params(self, img):
...@@ -40,6 +41,7 @@ class RandomChooseAug(ImageAugmentor): ...@@ -40,6 +41,7 @@ class RandomChooseAug(ImageAugmentor):
""" """
:param aug_lists: list of augmentor, or list of (augmentor, probability) tuple :param aug_lists: list of augmentor, or list of (augmentor, probability) tuple
""" """
super(RandomChooseAug, self).__init__()
if isinstance(aug_lists[0], (tuple, list)): if isinstance(aug_lists[0], (tuple, list)):
prob = [k[1] for k in aug_lists] prob = [k[1] for k in aug_lists]
aug_lists = [k[0] for k in aug_lists] aug_lists = [k[0] for k in aug_lists]
......
...@@ -11,6 +11,7 @@ __all__ = ['JpegNoise', 'GaussianNoise'] ...@@ -11,6 +11,7 @@ __all__ = ['JpegNoise', 'GaussianNoise']
class JpegNoise(ImageAugmentor): class JpegNoise(ImageAugmentor):
def __init__(self, quality_range=(40, 100)): def __init__(self, quality_range=(40, 100)):
super(JpegNoise, self).__init__()
self._init(locals()) self._init(locals())
def _get_augment_params(self, img): def _get_augment_params(self, img):
...@@ -23,6 +24,7 @@ class JpegNoise(ImageAugmentor): ...@@ -23,6 +24,7 @@ class JpegNoise(ImageAugmentor):
class GaussianNoise(ImageAugmentor): class GaussianNoise(ImageAugmentor):
def __init__(self, scale=10, clip=True): def __init__(self, scale=10, clip=True):
super(GaussianNoise, self).__init__()
self._init(locals()) self._init(locals())
def _get_augment_params(self, img): def _get_augment_params(self, img):
...@@ -33,3 +35,17 @@ class GaussianNoise(ImageAugmentor): ...@@ -33,3 +35,17 @@ class GaussianNoise(ImageAugmentor):
if self.clip: if self.clip:
ret = np.clip(ret, 0, 255) ret = np.clip(ret, 0, 255)
return ret return ret
class SaltPepperNoise(ImageAugmentor):
def __init__(self, white_prob=0.05, black_prob=0.05):
assert white_prob + black_prob <= 1, "Sum of probabilities cannot be greater than 1"
super(SaltPepperNoise, self).__init__()
self._init(locals())
def _get_augment_params(self, img):
return self.rng.uniform(low=0, high=1, size=img.shape)
def _augment(self, img, param):
img[param > (1 - self.white_prob)] = 255
img[param < self.black_prob] = 0
return img
...@@ -21,6 +21,7 @@ class Flip(ImageAugmentor): ...@@ -21,6 +21,7 @@ class Flip(ImageAugmentor):
:param vert: whether or not apply vertical flip. :param vert: whether or not apply vertical flip.
:param prob: probability of flip. :param prob: probability of flip.
""" """
super(Flip, self).__init__()
if horiz and vert: if horiz and vert:
raise ValueError("Please use two Flip instead.") raise ValueError("Please use two Flip instead.")
elif horiz: elif horiz:
...@@ -66,6 +67,7 @@ class RandomResize(ImageAugmentor): ...@@ -66,6 +67,7 @@ class RandomResize(ImageAugmentor):
:param minimum: (xmin, ymin). Avoid scaling down too much. :param minimum: (xmin, ymin). Avoid scaling down too much.
:param aspect_ratio_thres: at most change k=20% aspect ratio :param aspect_ratio_thres: at most change k=20% aspect ratio
""" """
super(RandomResize, self).__init__()
self._init(locals()) self._init(locals())
def _get_augment_params(self, img): def _get_augment_params(self, img):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment