Commit d7a020fc authored by Yuxin Wu

refactor imgaug

parent f6b502d7
...@@ -7,7 +7,7 @@ import cv2 ...@@ -7,7 +7,7 @@ import cv2
import copy import copy
from .base import DataFlow, ProxyDataFlow from .base import DataFlow, ProxyDataFlow
from .common import MapDataComponent from .common import MapDataComponent
from .imgaug import AugmentorList, Image from .imgaug import AugmentorList
__all__ = ['ImageFromFile', 'AugmentImageComponent'] __all__ = ['ImageFromFile', 'AugmentImageComponent']
......
...@@ -5,21 +5,26 @@ ...@@ -5,21 +5,26 @@
import sys import sys
import cv2 import cv2
from . import AugmentorList, Image from . import AugmentorList
from .crop import * from .crop import *
from .imgproc import *
from .noname import *
from .deform import *
anchors = [(0.2, 0.2), (0.8, 0.8), (0.5, 0.5), (0.2, 0.5)] anchors = [(0.2, 0.2), (0.7, 0.2), (0.8, 0.8), (0.5, 0.5), (0.2, 0.5)]
augmentors = AugmentorList([ augmentors = AugmentorList([
#Contrast((0.2,1.8)), Contrast((0.8,1.2)),
#Flip(horiz=True), Flip(horiz=True),
#GaussianDeform(anchors, (360,480), 1, randrange=10) GaussianDeform(anchors, (360,480), 0.2, randrange=20),
RandomCropRandomShape(0.3) #RandomCropRandomShape(0.3)
]) ])
while True: img = cv2.imread(sys.argv[1])
img = cv2.imread(sys.argv[1]) newimg, prms = augmentors._augment_return_params(img)
img = Image(img) cv2.imshow(" ", newimg.astype('uint8'))
augmentors.augment(img) cv2.waitKey()
cv2.imshow(" ", img.arr.astype('uint8'))
cv2.waitKey() newimg = augmentors._augment(img, prms)
cv2.imshow(" ", newimg.astype('uint8'))
cv2.waitKey()
...@@ -5,20 +5,7 @@ ...@@ -5,20 +5,7 @@
from abc import abstractmethod, ABCMeta from abc import abstractmethod, ABCMeta
from ...utils import get_rng from ...utils import get_rng
__all__ = ['Image', 'ImageAugmentor', 'AugmentorList'] __all__ = ['ImageAugmentor', 'AugmentorList']
class Image(object):
""" An image class with attributes, for augmentor to operate on.
Attributes (such as coordinates) have to be augmented acoordingly by
the augmentor, if necessary.
"""
def __init__(self, arr, coords=None):
"""
:param arr: the image array. Expected to be of [h, w, c] or [h, w]
:param coords: keypoint coordinates.
"""
self.arr = arr
self.coords = coords
class ImageAugmentor(object): class ImageAugmentor(object):
""" Base class for an image augmentor""" """ Base class for an image augmentor"""
...@@ -40,16 +27,33 @@ class ImageAugmentor(object): ...@@ -40,16 +27,33 @@ class ImageAugmentor(object):
def augment(self, img): def augment(self, img):
""" """
Perform augmentation on the image in-place. Perform augmentation on the image in-place.
:param img: an `Image` instance. :param img: an [h,w] or [h,w,c] image
:returns: the augmented `Image` instance. arr will always be of type :returns: the augmented image, always of type 'float32'
'float32' after augmentation.
""" """
self._augment(img) img, params = self._augment_return_params(img)
return img return img
def _augment_return_params(self, img):
"""
Augment the image and return both image and params
"""
prms = self._get_augment_params(img)
return (self._augment(img, prms), prms)
@abstractmethod @abstractmethod
def _augment(self, img): def _augment(self, img, param):
pass """
augment with the given param and return the new image
"""
def _get_augment_params(self, img):
"""
get the augmentor parameters
"""
return None
def _fprop_coord(self, coord, param):
return coord
def _rand_range(self, low=1.0, high=None, size=None): def _rand_range(self, low=1.0, high=None, size=None):
if high is None: if high is None:
...@@ -67,12 +71,24 @@ class AugmentorList(ImageAugmentor): ...@@ -67,12 +71,24 @@ class AugmentorList(ImageAugmentor):
:param augmentors: list of `ImageAugmentor` instance to be applied :param augmentors: list of `ImageAugmentor` instance to be applied
""" """
self.augs = augmentors self.augs = augmentors
super(AugmentorList, self).__init__()
def _augment(self, img): def _get_augment_params(self, img):
assert img.arr.ndim in [2, 3], img.arr.ndim raise RuntimeError("Cannot simply get parameters of a AugmentorList!")
img.arr = img.arr.astype('float32')
for aug in self.augs: def _augment_return_params(self, img):
aug.augment(img) prms = []
for a in self.augs:
img, prm = a._augment_return_params(img)
prms.append(prm)
return img, prms
def _augment(self, img, param):
assert img.ndim in [2, 3], img.ndim
img = img.astype('float32')
for aug, prm in zip(self.augs, param):
img = aug._augment(img, prm)
return img
def reset_state(self): def reset_state(self):
""" Will reset state of each augmentor """ """ Will reset state of each augmentor """
......
...@@ -18,13 +18,18 @@ class RandomCrop(ImageAugmentor): ...@@ -18,13 +18,18 @@ class RandomCrop(ImageAugmentor):
""" """
self._init(locals()) self._init(locals())
def _augment(self, img): def _get_augment_params(self, img):
orig_shape = img.arr.shape orig_shape = img.shape
h0 = self.rng.randint(0, orig_shape[0] - self.crop_shape[0]) h0 = self.rng.randint(0, orig_shape[0] - self.crop_shape[0])
w0 = self.rng.randint(0, orig_shape[1] - self.crop_shape[1]) w0 = self.rng.randint(0, orig_shape[1] - self.crop_shape[1])
img.arr = img.arr[h0:h0+self.crop_shape[0],w0:w0+self.crop_shape[1]] return (h0, w0)
if img.coords:
raise NotImplementedError() def _augment(self, img, param):
h0, w0 = param
return img[h0:h0+self.crop_shape[0],w0:w0+self.crop_shape[1]]
def _fprop_coord(self, coord, param):
raise NotImplementedError()
class CenterCrop(ImageAugmentor): class CenterCrop(ImageAugmentor):
""" Crop the image at the center""" """ Crop the image at the center"""
...@@ -34,13 +39,14 @@ class CenterCrop(ImageAugmentor): ...@@ -34,13 +39,14 @@ class CenterCrop(ImageAugmentor):
""" """
self._init(locals()) self._init(locals())
def _augment(self, img): def _augment(self, img, _):
orig_shape = img.arr.shape orig_shape = img.shape
h0 = int((orig_shape[0] - self.crop_shape[0]) * 0.5) h0 = int((orig_shape[0] - self.crop_shape[0]) * 0.5)
w0 = int((orig_shape[1] - self.crop_shape[1]) * 0.5) w0 = int((orig_shape[1] - self.crop_shape[1]) * 0.5)
img.arr = img.arr[h0:h0+self.crop_shape[0],w0:w0+self.crop_shape[1]] return img[h0:h0+self.crop_shape[0],w0:w0+self.crop_shape[1]]
if img.coords:
raise NotImplementedError() def _fprop_coord(self, coord, param):
raise NotImplementedError()
class FixedCrop(ImageAugmentor): class FixedCrop(ImageAugmentor):
""" Crop a rectangle at a given location""" """ Crop a rectangle at a given location"""
...@@ -52,12 +58,13 @@ class FixedCrop(ImageAugmentor): ...@@ -52,12 +58,13 @@ class FixedCrop(ImageAugmentor):
""" """
self._init(locals()) self._init(locals())
def _augment(self, img): def _augment(self, img, _):
orig_shape = img.arr.shape orig_shape = img.shape
img.arr = img.arr[self.rect.y0: self.rect.y1+1, return img[self.rect.y0: self.rect.y1+1,
self.rect.x0: self.rect.x0+1] self.rect.x0: self.rect.x0+1]
if img.coords:
raise NotImplementedError() def _fprop_coord(self, coord, param):
raise NotImplementedError()
def perturb_BB(image_shape, bb, max_pertub_pixel, def perturb_BB(image_shape, bb, max_pertub_pixel,
rng=None, max_aspect_ratio_diff=0.3, rng=None, max_aspect_ratio_diff=0.3,
...@@ -104,16 +111,19 @@ class RandomCropRandomShape(ImageAugmentor): ...@@ -104,16 +111,19 @@ class RandomCropRandomShape(ImageAugmentor):
""" """
self._init(locals()) self._init(locals())
def _augment(self, img): def _get_augment_params(self, img):
shape = img.arr.shape[:2] shape = img.shape[:2]
box = Rect(0, 0, shape[1] - 1, shape[0] - 1) box = Rect(0, 0, shape[1] - 1, shape[0] - 1)
dist = self.perturb_ratio * np.sqrt(shape[0]*shape[1]) dist = self.perturb_ratio * np.sqrt(shape[0]*shape[1])
newbox = perturb_BB(shape, box, dist, newbox = perturb_BB(shape, box, dist,
self.rng, self.max_aspect_ratio_diff) self.rng, self.max_aspect_ratio_diff)
return newbox
def _augment(self, img, newbox):
return newbox.roi(img)
img.arr = newbox.roi(img.arr) def _fprop_coord(self, coord, param):
if img.coords: raise NotImplementedError()
raise NotImplementedError()
if __name__ == '__main__': if __name__ == '__main__':
print(perturb_BB([100, 100], Rect(3, 3, 50, 50), 50)) print(perturb_BB([100, 100], Rect(3, 3, 50, 50), 50))
...@@ -81,10 +81,15 @@ class GaussianDeform(ImageAugmentor): ...@@ -81,10 +81,15 @@ class GaussianDeform(ImageAugmentor):
else: else:
self.randrange = randrange self.randrange = randrange
def _augment(self, img): def _get_augment_params(self, img):
if img.coords:
raise NotImplementedError()
v = self.rng.rand(self.K, 2).astype('float32') - 0.5 v = self.rng.rand(self.K, 2).astype('float32') - 0.5
v = v * 2 * self.randrange v = v * 2 * self.randrange
return v
def _augment(self, img, v):
grid = self.grid + np.dot(self.gws, v) grid = self.grid + np.dot(self.gws, v)
img.arr = np_sample(img.arr, grid) print(grid)
return np_sample(img, grid)
def _fprop_coord(self, coord, param):
raise NotImplementedError()
...@@ -18,11 +18,15 @@ class Brightness(ImageAugmentor): ...@@ -18,11 +18,15 @@ class Brightness(ImageAugmentor):
assert delta > 0 assert delta > 0
self._init(locals()) self._init(locals())
def _augment(self, img): def _get_augment_params(self, img):
v = self._rand_range(-self.delta, self.delta) v = self._rand_range(-self.delta, self.delta)
img.arr += v return v
def _augment(self, img, v):
img += v
if self.clip: if self.clip:
img.arr = np.clip(img.arr, 0, 255) img = np.clip(img, 0, 255)
return img
class Contrast(ImageAugmentor): class Contrast(ImageAugmentor):
""" """
...@@ -36,13 +40,15 @@ class Contrast(ImageAugmentor): ...@@ -36,13 +40,15 @@ class Contrast(ImageAugmentor):
""" """
self._init(locals()) self._init(locals())
def _augment(self, img): def _get_augment_params(self, img):
arr = img.arr return self._rand_range(*self.factor_range)
r = self._rand_range(*self.factor_range)
mean = np.mean(arr, axis=(0,1), keepdims=True) def _augment(self, img, r):
img.arr = (arr - mean) * r + mean mean = np.mean(img, axis=(0,1), keepdims=True)
img = (img - mean) * r + mean
if self.clip: if self.clip:
img.arr = np.clip(img.arr, 0, 255) img = np.clip(img, 0, 255)
return img
class MeanVarianceNormalize(ImageAugmentor): class MeanVarianceNormalize(ImageAugmentor):
""" """
...@@ -56,12 +62,13 @@ class MeanVarianceNormalize(ImageAugmentor): ...@@ -56,12 +62,13 @@ class MeanVarianceNormalize(ImageAugmentor):
""" """
self.all_channel = all_channel self.all_channel = all_channel
def _augment(self, img): def _augment(self, img, _):
if self.all_channel: if self.all_channel:
mean = np.mean(img.arr) mean = np.mean(img)
std = np.std(img.arr) std = np.std(img)
else: else:
mean = np.mean(img.arr, axis=(0,1), keepdims=True) mean = np.mean(img, axis=(0,1), keepdims=True)
std = np.std(img.arr, axis=(0,1), keepdims=True) std = np.std(img, axis=(0,1), keepdims=True)
std = np.maximum(std, 1.0 / np.sqrt(np.prod(img.arr.shape))) std = np.maximum(std, 1.0 / np.sqrt(np.prod(img.shape)))
img.arr = (img.arr - mean) / std img = (img - mean) / std
return img
...@@ -31,11 +31,16 @@ class Flip(ImageAugmentor): ...@@ -31,11 +31,16 @@ class Flip(ImageAugmentor):
self.prob = prob self.prob = prob
self._init() self._init()
def _augment(self, img): def _get_augment_params(self, img):
if self._rand_range() < self.prob: return self._rand_range() < self.prob
img.arr = cv2.flip(img.arr, self.code)
if img.coords: def _augment(self, img, do):
raise NotImplementedError() if do:
img = cv2.flip(img, self.code)
return img
def _fprop_coord(self, coord, param):
raise NotImplementedError()
class MapImage(ImageAugmentor): class MapImage(ImageAugmentor):
...@@ -48,8 +53,8 @@ class MapImage(ImageAugmentor): ...@@ -48,8 +53,8 @@ class MapImage(ImageAugmentor):
""" """
self.func = func self.func = func
def _augment(self, img): def _augment(self, img, _):
img.arr = self.func(img.arr) img = self.func(img)
class Resize(ImageAugmentor): class Resize(ImageAugmentor):
...@@ -60,7 +65,7 @@ class Resize(ImageAugmentor): ...@@ -60,7 +65,7 @@ class Resize(ImageAugmentor):
""" """
self._init(locals()) self._init(locals())
def _augment(self, img): def _augment(self, img, _):
img.arr = cv2.resize( img.arr = cv2.resize(
img.arr, self.shape[::-1], img.arr, self.shape[::-1],
interpolation=cv2.INTER_CUBIC) interpolation=cv2.INTER_CUBIC)
...@@ -57,16 +57,16 @@ class CenterPaste(ImageAugmentor): ...@@ -57,16 +57,16 @@ class CenterPaste(ImageAugmentor):
self._init(locals()) self._init(locals())
def _augment(self, img): def _augment(self, img, _):
img_shape = img.arr.shape[:2] img_shape = img.shape[:2]
assert self.background_shape[0] > img_shape[0] and self.background_shape[1] > img_shape[1] assert self.background_shape[0] > img_shape[0] and self.background_shape[1] > img_shape[1]
background = self.background_filler.fill( background = self.background_filler.fill(
self.background_shape, img.arr) self.background_shape, img)
h0 = int((self.background_shape[0] - img_shape[0]) * 0.5) h0 = int((self.background_shape[0] - img_shape[0]) * 0.5)
w0 = int((self.background_shape[1] - img_shape[1]) * 0.5) w0 = int((self.background_shape[1] - img_shape[1]) * 0.5)
background[h0:h0+img_shape[0], w0:w0+img_shape[1]] = img.arr background[h0:h0+img_shape[0], w0:w0+img_shape[1]] = img
img.arr = background img = background
if img.coords: if img.coords:
raise NotImplementedError() raise NotImplementedError()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment