Commit d7a020fc authored by Yuxin Wu

refactor imgaug

parent f6b502d7
......@@ -7,7 +7,7 @@ import cv2
import copy
from .base import DataFlow, ProxyDataFlow
from .common import MapDataComponent
from .imgaug import AugmentorList, Image
from .imgaug import AugmentorList
__all__ = ['ImageFromFile', 'AugmentImageComponent']
......
......@@ -5,21 +5,26 @@
import sys
import cv2
from . import AugmentorList, Image
from . import AugmentorList
from .crop import *
from .imgproc import *
from .noname import *
from .deform import *
anchors = [(0.2, 0.2), (0.8, 0.8), (0.5, 0.5), (0.2, 0.5)]
anchors = [(0.2, 0.2), (0.7, 0.2), (0.8, 0.8), (0.5, 0.5), (0.2, 0.5)]
augmentors = AugmentorList([
#Contrast((0.2,1.8)),
#Flip(horiz=True),
#GaussianDeform(anchors, (360,480), 1, randrange=10)
RandomCropRandomShape(0.3)
Contrast((0.8,1.2)),
Flip(horiz=True),
GaussianDeform(anchors, (360,480), 0.2, randrange=20),
#RandomCropRandomShape(0.3)
])
while True:
img = cv2.imread(sys.argv[1])
img = Image(img)
augmentors.augment(img)
cv2.imshow(" ", img.arr.astype('uint8'))
cv2.waitKey()
img = cv2.imread(sys.argv[1])
newimg, prms = augmentors._augment_return_params(img)
cv2.imshow(" ", newimg.astype('uint8'))
cv2.waitKey()
newimg = augmentors._augment(img, prms)
cv2.imshow(" ", newimg.astype('uint8'))
cv2.waitKey()
......@@ -5,20 +5,7 @@
from abc import abstractmethod, ABCMeta
from ...utils import get_rng
__all__ = ['Image', 'ImageAugmentor', 'AugmentorList']
class Image(object):
""" An image class with attributes, for augmentor to operate on.
Attributes (such as coordinates) have to be augmented acoordingly by
the augmentor, if necessary.
"""
def __init__(self, arr, coords=None):
"""
:param arr: the image array. Expected to be of [h, w, c] or [h, w]
:param coords: keypoint coordinates.
"""
self.arr = arr
self.coords = coords
__all__ = ['ImageAugmentor', 'AugmentorList']
class ImageAugmentor(object):
""" Base class for an image augmentor"""
......@@ -40,16 +27,33 @@ class ImageAugmentor(object):
def augment(self, img):
"""
Perform augmentation on the image in-place.
:param img: an `Image` instance.
:returns: the augmented `Image` instance. arr will always be of type
'float32' after augmentation.
:param img: an [h,w] or [h,w,c] image
:returns: the augmented image, always of type 'float32'
"""
self._augment(img)
img, params = self._augment_return_params(img)
return img
def _augment_return_params(self, img):
"""
Augment the image and return both image and params
"""
prms = self._get_augment_params(img)
return (self._augment(img, prms), prms)
@abstractmethod
def _augment(self, img):
pass
def _augment(self, img, param):
"""
augment with the given param and return the new image
"""
def _get_augment_params(self, img):
"""
get the augmentor parameters
"""
return None
def _fprop_coord(self, coord, param):
return coord
def _rand_range(self, low=1.0, high=None, size=None):
if high is None:
......@@ -67,12 +71,24 @@ class AugmentorList(ImageAugmentor):
:param augmentors: list of `ImageAugmentor` instance to be applied
"""
self.augs = augmentors
super(AugmentorList, self).__init__()
def _augment(self, img):
assert img.arr.ndim in [2, 3], img.arr.ndim
img.arr = img.arr.astype('float32')
for aug in self.augs:
aug.augment(img)
def _get_augment_params(self, img):
raise RuntimeError("Cannot simply get parameters of a AugmentorList!")
def _augment_return_params(self, img):
prms = []
for a in self.augs:
img, prm = a._augment_return_params(img)
prms.append(prm)
return img, prms
def _augment(self, img, param):
assert img.ndim in [2, 3], img.ndim
img = img.astype('float32')
for aug, prm in zip(self.augs, param):
img = aug._augment(img, prm)
return img
def reset_state(self):
""" Will reset state of each augmentor """
......
......@@ -18,13 +18,18 @@ class RandomCrop(ImageAugmentor):
"""
self._init(locals())
def _augment(self, img):
orig_shape = img.arr.shape
def _get_augment_params(self, img):
orig_shape = img.shape
h0 = self.rng.randint(0, orig_shape[0] - self.crop_shape[0])
w0 = self.rng.randint(0, orig_shape[1] - self.crop_shape[1])
img.arr = img.arr[h0:h0+self.crop_shape[0],w0:w0+self.crop_shape[1]]
if img.coords:
raise NotImplementedError()
return (h0, w0)
def _augment(self, img, param):
h0, w0 = param
return img[h0:h0+self.crop_shape[0],w0:w0+self.crop_shape[1]]
def _fprop_coord(self, coord, param):
raise NotImplementedError()
class CenterCrop(ImageAugmentor):
""" Crop the image at the center"""
......@@ -34,13 +39,14 @@ class CenterCrop(ImageAugmentor):
"""
self._init(locals())
def _augment(self, img):
orig_shape = img.arr.shape
def _augment(self, img, _):
orig_shape = img.shape
h0 = int((orig_shape[0] - self.crop_shape[0]) * 0.5)
w0 = int((orig_shape[1] - self.crop_shape[1]) * 0.5)
img.arr = img.arr[h0:h0+self.crop_shape[0],w0:w0+self.crop_shape[1]]
if img.coords:
raise NotImplementedError()
return img[h0:h0+self.crop_shape[0],w0:w0+self.crop_shape[1]]
def _fprop_coord(self, coord, param):
raise NotImplementedError()
class FixedCrop(ImageAugmentor):
""" Crop a rectangle at a given location"""
......@@ -52,12 +58,13 @@ class FixedCrop(ImageAugmentor):
"""
self._init(locals())
def _augment(self, img):
orig_shape = img.arr.shape
img.arr = img.arr[self.rect.y0: self.rect.y1+1,
self.rect.x0: self.rect.x0+1]
if img.coords:
raise NotImplementedError()
def _augment(self, img, _):
orig_shape = img.shape
return img[self.rect.y0: self.rect.y1+1,
self.rect.x0: self.rect.x0+1]
def _fprop_coord(self, coord, param):
raise NotImplementedError()
def perturb_BB(image_shape, bb, max_pertub_pixel,
rng=None, max_aspect_ratio_diff=0.3,
......@@ -104,16 +111,19 @@ class RandomCropRandomShape(ImageAugmentor):
"""
self._init(locals())
def _augment(self, img):
shape = img.arr.shape[:2]
def _get_augment_params(self, img):
shape = img.shape[:2]
box = Rect(0, 0, shape[1] - 1, shape[0] - 1)
dist = self.perturb_ratio * np.sqrt(shape[0]*shape[1])
newbox = perturb_BB(shape, box, dist,
self.rng, self.max_aspect_ratio_diff)
return newbox
def _augment(self, img, newbox):
return newbox.roi(img)
img.arr = newbox.roi(img.arr)
if img.coords:
raise NotImplementedError()
def _fprop_coord(self, coord, param):
raise NotImplementedError()
if __name__ == '__main__':
print(perturb_BB([100, 100], Rect(3, 3, 50, 50), 50))
......@@ -81,10 +81,15 @@ class GaussianDeform(ImageAugmentor):
else:
self.randrange = randrange
def _augment(self, img):
if img.coords:
raise NotImplementedError()
def _get_augment_params(self, img):
v = self.rng.rand(self.K, 2).astype('float32') - 0.5
v = v * 2 * self.randrange
return v
def _augment(self, img, v):
grid = self.grid + np.dot(self.gws, v)
img.arr = np_sample(img.arr, grid)
print(grid)
return np_sample(img, grid)
def _fprop_coord(self, coord, param):
raise NotImplementedError()
......@@ -18,11 +18,15 @@ class Brightness(ImageAugmentor):
assert delta > 0
self._init(locals())
def _augment(self, img):
def _get_augment_params(self, img):
v = self._rand_range(-self.delta, self.delta)
img.arr += v
return v
def _augment(self, img, v):
img += v
if self.clip:
img.arr = np.clip(img.arr, 0, 255)
img = np.clip(img, 0, 255)
return img
class Contrast(ImageAugmentor):
"""
......@@ -36,13 +40,15 @@ class Contrast(ImageAugmentor):
"""
self._init(locals())
def _augment(self, img):
arr = img.arr
r = self._rand_range(*self.factor_range)
mean = np.mean(arr, axis=(0,1), keepdims=True)
img.arr = (arr - mean) * r + mean
def _get_augment_params(self, img):
return self._rand_range(*self.factor_range)
def _augment(self, img, r):
mean = np.mean(img, axis=(0,1), keepdims=True)
img = (img - mean) * r + mean
if self.clip:
img.arr = np.clip(img.arr, 0, 255)
img = np.clip(img, 0, 255)
return img
class MeanVarianceNormalize(ImageAugmentor):
"""
......@@ -56,12 +62,13 @@ class MeanVarianceNormalize(ImageAugmentor):
"""
self.all_channel = all_channel
def _augment(self, img):
def _augment(self, img, _):
if self.all_channel:
mean = np.mean(img.arr)
std = np.std(img.arr)
mean = np.mean(img)
std = np.std(img)
else:
mean = np.mean(img.arr, axis=(0,1), keepdims=True)
std = np.std(img.arr, axis=(0,1), keepdims=True)
std = np.maximum(std, 1.0 / np.sqrt(np.prod(img.arr.shape)))
img.arr = (img.arr - mean) / std
mean = np.mean(img, axis=(0,1), keepdims=True)
std = np.std(img, axis=(0,1), keepdims=True)
std = np.maximum(std, 1.0 / np.sqrt(np.prod(img.shape)))
img = (img - mean) / std
return img
......@@ -31,11 +31,16 @@ class Flip(ImageAugmentor):
self.prob = prob
self._init()
def _augment(self, img):
if self._rand_range() < self.prob:
img.arr = cv2.flip(img.arr, self.code)
if img.coords:
raise NotImplementedError()
def _get_augment_params(self, img):
return self._rand_range() < self.prob
def _augment(self, img, do):
if do:
img = cv2.flip(img, self.code)
return img
def _fprop_coord(self, coord, param):
raise NotImplementedError()
class MapImage(ImageAugmentor):
......@@ -48,8 +53,8 @@ class MapImage(ImageAugmentor):
"""
self.func = func
def _augment(self, img):
img.arr = self.func(img.arr)
def _augment(self, img, _):
img = self.func(img)
class Resize(ImageAugmentor):
......@@ -60,7 +65,7 @@ class Resize(ImageAugmentor):
"""
self._init(locals())
def _augment(self, img):
def _augment(self, img, _):
img.arr = cv2.resize(
img.arr, self.shape[::-1],
interpolation=cv2.INTER_CUBIC)
......@@ -57,16 +57,16 @@ class CenterPaste(ImageAugmentor):
self._init(locals())
def _augment(self, img):
img_shape = img.arr.shape[:2]
def _augment(self, img, _):
img_shape = img.shape[:2]
assert self.background_shape[0] > img_shape[0] and self.background_shape[1] > img_shape[1]
background = self.background_filler.fill(
self.background_shape, img.arr)
self.background_shape, img)
h0 = int((self.background_shape[0] - img_shape[0]) * 0.5)
w0 = int((self.background_shape[1] - img_shape[1]) * 0.5)
background[h0:h0+img_shape[0], w0:w0+img_shape[1]] = img.arr
img.arr = background
background[h0:h0+img_shape[0], w0:w0+img_shape[1]] = img
img = background
if img.coords:
raise NotImplementedError()
......
Markdown is supported
0% Try again or attach a new file
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment