Commit faf00e48 authored by haamoon's avatar haamoon Committed by Yuxin Wu

add _augment_coords implementation for many ImageAugmentor subclasses (#335)

* add _augment_coords implementation for many ImageAugmentor subclasses

* Fixed style issue

* Fixed style issues

* Added AugmentImageCoordinates

* Fixed style issues

* remove unused comments
parent 4c2bcb94
......@@ -9,7 +9,7 @@ from .common import MapDataComponent, MapData
from ..utils import logger
from ..utils.argtools import shape2d
__all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageComponents']
__all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageCoordinates', 'AugmentImageComponents']
class ImageFromFile(RNGDataFlow):
......@@ -91,6 +91,53 @@ class AugmentImageComponent(MapDataComponent):
self.augs.reset_state()
class AugmentImageCoordinates(MapData):
    """
    Apply image augmentors on an image and a set of coordinates, using the
    same sampled augmentation parameters for both.
    """
    def __init__(self, ds, augmentors, img_index=0, coords_index=1, copy=True):
        """
        Args:
            ds (DataFlow): input DataFlow.
            augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` to be applied in order.
            img_index (int): the index of the image component to be augmented.
            coords_index (int): the index of the coordinate component to be augmented.
            copy (bool): Some augmentors modify their inputs. When copy is
                True, the image and coordinates are deep-copied before any
                augmentor is applied, so the originals stay unmodified.
                Turn it off to save time when you know it's OK.
        """
        # Accept either a ready-made AugmentorList or a plain list of augmentors.
        self.augs = augmentors if isinstance(augmentors, AugmentorList) else AugmentorList(augmentors)
        self._nr_error = 0

        def func(dp):
            try:
                image, points = dp[img_index], dp[coords_index]
                if copy:
                    image, points = copy_mod.deepcopy((image, points))
                # Augment the image first and keep the parameters, so the
                # coordinates receive exactly the same transformation.
                image, params = self.augs._augment_return_params(image)
                dp[img_index] = image
                dp[coords_index] = self.augs._augment_coords(points, params)
                return dp
            except KeyboardInterrupt:
                raise
            except Exception:
                # Count the failure; log only the first few and every 1000th.
                self._nr_error += 1
                if self._nr_error < 10 or self._nr_error % 1000 == 0:
                    logger.exception("Got {} augmentation errors.".format(self._nr_error))
                return None

        super(AugmentImageCoordinates, self).__init__(ds, func)

    def reset_state(self):
        """Reset the state of the wrapped DataFlow and of the augmentors."""
        self.ds.reset_state()
        self.augs.reset_state()
class AugmentImageComponents(MapData):
"""
Apply image augmentors on several components, with shared augmentation parameters.
......@@ -146,4 +193,5 @@ except ImportError:
from ..utils.develop import create_dummy_class
ImageFromFile = create_dummy_class('ImageFromFile', 'cv2') # noqa
AugmentImageComponent = create_dummy_class('AugmentImageComponent', 'cv2') # noqa
AugmentImageCoordinates = create_dummy_class('AugmentImageCoordinates', 'cv2') # noqa
AugmentImageComponents = create_dummy_class('AugmentImageComponents', 'cv2') # noqa
......@@ -98,6 +98,11 @@ class AugmentorList(ImageAugmentor):
img = aug._augment(img, prm)
return img
def _augment_coords(self, coords, param):
for aug, prm in zip(self.augs, param):
coords = aug._augment_coords(coords, prm)
return coords
def reset_state(self):
""" Will reset state of each augmentor """
for a in self.augs:
......
......@@ -40,7 +40,10 @@ class RandomCrop(ImageAugmentor):
return img[h0:h0 + self.crop_shape[0], w0:w0 + self.crop_shape[1]]
def _augment_coords(self, coords, param):
raise NotImplementedError()
h0, w0 = param
coords[:, 0] = coords[:, 0] - w0
coords[:, 1] = coords[:, 1] - h0
return coords
class CenterCrop(ImageAugmentor):
......@@ -54,14 +57,21 @@ class CenterCrop(ImageAugmentor):
crop_shape = shape2d(crop_shape)
self._init(locals())
def _augment(self, img, _):
def _get_augment_params(self, img):
orig_shape = img.shape
h0 = int((orig_shape[0] - self.crop_shape[0]) * 0.5)
w0 = int((orig_shape[1] - self.crop_shape[1]) * 0.5)
return (h0, w0)
def _augment(self, img, param):
h0, w0 = param
return img[h0:h0 + self.crop_shape[0], w0:w0 + self.crop_shape[1]]
def _augment_coords(self, coords, param):
raise NotImplementedError()
h0, w0 = param
coords[:, 0] = coords[:, 0] - w0
coords[:, 1] = coords[:, 1] - h0
return coords
def perturb_BB(image_shape, bb, max_perturb_pixel,
......@@ -127,8 +137,10 @@ class RandomCropAroundBox(ImageAugmentor):
def _augment(self, img, newbox):
return newbox.roi(img)
def _augment_coords(self, coords, param):
raise NotImplementedError()
def _augment_coords(self, coords, newbox):
coords[:, 0] = coords[:, 0] - newbox.x0
coords[:, 1] = coords[:, 1] - newbox.y0
return coords
class RandomCropRandomShape(ImageAugmentor):
......@@ -165,6 +177,12 @@ class RandomCropRandomShape(ImageAugmentor):
y0, x0, h, w = param
return img[y0:y0 + h, x0:x0 + w]
def _augment_coords(self, coords, param):
y0, x0, _, _ = param
coords[:, 0] = coords[:, 0] - x0
coords[:, 1] = coords[:, 1] - y0
return coords
if __name__ == '__main__':
    # Manual smoke test: print the result of perturbing a sample box
    # for a 100x100 image shape.
    print(perturb_BB([100, 100], Rect(3, 3, 50, 50), 50))
......@@ -42,6 +42,9 @@ class Shift(ImageAugmentor):
ret = ret[:, :, np.newaxis]
return ret
def _augment_coords(self, coords, param):
    """Coordinate augmentation is not implemented here; always raises NotImplementedError."""
    raise NotImplementedError()
class Rotation(ImageAugmentor):
""" Random rotate the image w.r.t a random center"""
......@@ -79,6 +82,9 @@ class Rotation(ImageAugmentor):
ret = ret[:, :, np.newaxis]
return ret
def _augment_coords(self, coords, param):
    """Coordinate augmentation is not implemented here; always raises NotImplementedError."""
    raise NotImplementedError()
class RotationAndCropValid(ImageAugmentor):
""" Random rotate and then crop the largest possible rectangle.
......@@ -115,6 +121,9 @@ class RotationAndCropValid(ImageAugmentor):
# print(ret.shape, deg, newx, newy, neww, newh)
return ret[newy:newy + newh, newx:newx + neww]
def _augment_coords(self, coords, param):
    """Coordinate augmentation is not implemented here; always raises NotImplementedError."""
    raise NotImplementedError()
@staticmethod
def largest_rotated_rect(w, h, angle):
"""
......
......@@ -56,6 +56,12 @@ class RandomApplyAug(ImageAugmentor):
else:
return self.aug._augment(img, prm[1])
def _augment_coords(self, coords, prm):
if not prm[0]:
return coords
else:
return self.aug._augment_coords(coords, prm[1])
class RandomChooseAug(ImageAugmentor):
""" Randomly choose one from a list of augmentors """
......@@ -69,7 +75,7 @@ class RandomChooseAug(ImageAugmentor):
aug_lists = [k[0] for k in aug_lists]
self._init(locals())
else:
prob = 1.0 / len(aug_lists)
prob = [1.0 / len(aug_lists)] * len(aug_lists)
self._init(locals())
super(RandomChooseAug, self).__init__()
......@@ -87,6 +93,10 @@ class RandomChooseAug(ImageAugmentor):
idx, prm = prm
return self.aug_lists[idx]._augment(img, prm)
def _augment_coords(self, coords, prm):
idx, prm = prm
return self.aug_lists[idx]._augment_coords(coords, prm)
class RandomOrderAug(ImageAugmentor):
"""
......@@ -121,18 +131,30 @@ class RandomOrderAug(ImageAugmentor):
img = self.aug_lists[k]._augment(img, prms[k])
return img
def _augment_coords(self, coords, prm):
idxs, prms = prm
for k in idxs:
img = self.aug_lists[k]._augment_coords(coords, prms[k])
return img
class MapImage(ImageAugmentor):
    """
    Map the image array by a function.
    """
    def __init__(self, func, coord_func=None):
        """
        Args:
            func: a function which takes an image array and return an augmented one
            coord_func: optional function which takes a coordinate array and
                returns the augmented one. When it is None, augmenting
                coordinates raises NotImplementedError.
        """
        self.func = func
        self.coord_func = coord_func

    def _augment(self, img, _):
        """Return ``self.func(img)``; the parameter argument is ignored."""
        return self.func(img)

    def _augment_coords(self, coords, _):
        """
        Return ``self.coord_func(coords)``; the parameter argument is ignored.

        Raises:
            NotImplementedError: if no ``coord_func`` was given.
        """
        if self.coord_func is None:
            raise NotImplementedError
        return self.coord_func(coords)
......@@ -35,9 +35,12 @@ class Flip(ImageAugmentor):
self._init()
def _get_augment_params(self, img):
return self._rand_range() < self.prob
h, w = img.shape[:2]
do = self._rand_range() < self.prob
return (do, h, w)
def _augment(self, img, do):
def _augment(self, img, param):
do, _, _ = param
if do:
ret = cv2.flip(img, self.code)
if img.ndim == 3 and ret.ndim == 2:
......@@ -47,7 +50,13 @@ class Flip(ImageAugmentor):
return ret
def _augment_coords(self, coords, param):
raise NotImplementedError()
do, h, w = param
if do:
if self.code == 0:
coords[:, 1] = h - coords[:, 1]
elif self.code == 1:
coords[:, 0] = w - coords[:, 0]
return coords
class Resize(ImageAugmentor):
......@@ -62,6 +71,10 @@ class Resize(ImageAugmentor):
shape = tuple(shape2d(shape))
self._init(locals())
def _get_augment_params(self, img):
h, w = img.shape[:2]
return (h, w)
def _augment(self, img, _):
ret = cv2.resize(
img, self.shape[::-1],
......@@ -70,6 +83,12 @@ class Resize(ImageAugmentor):
ret = ret[:, :, np.newaxis]
return ret
def _augment_coords(self, coords, param):
h, w = param
coords[:, 0] = coords[:, 0] * self.shape[1] * 1.0 / w
coords[:, 1] = coords[:, 1] * self.shape[0] * 1.0 / h
return coords
class ResizeShortestEdge(ImageAugmentor):
"""
......@@ -85,15 +104,25 @@ class ResizeShortestEdge(ImageAugmentor):
size = size * 1.0
self._init(locals())
def _augment(self, img, _):
def _get_augment_params(self, img):
h, w = img.shape[:2]
scale = self.size / min(h, w)
desSize = map(int, [scale * w, scale * h])
ret = cv2.resize(img, tuple(desSize), interpolation=self.interp)
newh, neww = map(int, [scale * h, scale * w])
return (h, w, newh, neww)
def _augment(self, img, param):
    """
    Resize ``img`` to the target size stored in ``param``.

    ``param`` is ``(h, w, newh, neww)``; only the target size is used.
    """
    _, _, target_h, target_w = param
    # cv2.resize takes the destination size as (width, height).
    resized = cv2.resize(img, (target_w, target_h), interpolation=self.interp)
    lost_channel_axis = img.ndim == 3 and resized.ndim == 2
    if lost_channel_axis:
        # Restore the trailing channel axis so the rank matches the input.
        resized = resized[:, :, np.newaxis]
    return resized
def _augment_coords(self, coords, param):
h, w, newh, neww = param
coords[:, 0] = coords[:, 0] * neww * 1.0 / w
coords[:, 1] = coords[:, 1] * newh * 1.0 / h
return coords
class RandomResize(ImageAugmentor):
""" Randomly rescale w and h of the image"""
......@@ -117,30 +146,38 @@ class RandomResize(ImageAugmentor):
def _get_augment_params(self, img):
cnt = 0
h, w = img.shape[:2]
while True:
sx = self._rand_range(*self.xrange)
if self.aspect_ratio_thres == 0:
sy = sx
else:
sy = self._rand_range(*self.yrange)
destX = max(sx * img.shape[1], self.minimum[0])
destY = max(sy * img.shape[0], self.minimum[1])
oldr = img.shape[1] * 1.0 / img.shape[0]
destX = max(sx * w, self.minimum[0])
destY = max(sy * h, self.minimum[1])
oldr = w * 1.0 / h
newr = destX * 1.0 / destY
diff = abs(newr - oldr) / oldr
if diff <= self.aspect_ratio_thres + 1e-5:
return (int(destX), int(destY))
return (h, w, int(destY), int(destX))
cnt += 1
if cnt > 50:
logger.warn("RandomResize failed to augment an image")
return img.shape[1], img.shape[0]
return (h, w, h, w)
def _augment(self, img, param):
    """
    Resize ``img`` to the target size stored in ``param``.

    Args:
        img: image array.
        param: the ``(h, w, newh, neww)`` tuple produced for this image;
            only ``newh`` and ``neww`` are used here.

    Returns:
        The resized image. If the input had three dimensions and the resize
        returned two, a trailing axis is added back.
    """
    _, _, newh, neww = param
    # cv2.resize takes the destination size as (width, height).
    ret = cv2.resize(img, (neww, newh), interpolation=self.interp)
    if img.ndim == 3 and ret.ndim == 2:
        ret = ret[:, :, np.newaxis]
    return ret
def _augment_coords(self, coords, param):
h, w, newh, neww = param
coords[:, 0] = coords[:, 0] * neww * 1.0 / w
coords[:, 1] = coords[:, 1] * newh * 1.0 / h
return coords
class Transpose(ImageAugmentor):
"""
......@@ -166,5 +203,7 @@ class Transpose(ImageAugmentor):
ret = ret[:, :, np.newaxis]
return ret
def _augment_coords(self, coords, param):
raise NotImplementedError()
def _augment_coords(self, coords, do):
if do:
coords = coords[:, ::-1]
return coords
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment