Commit faf00e48 authored by haamoon's avatar haamoon Committed by Yuxin Wu

add _augment_coords implementation for many ImageAugmentor subclasses (#335)

* add _augment_coords implementation for many ImageAugmentor subclasses

* Fixed style issue

* Fixed style issues

* Added AugmentImageCoordinates

* Fixed style issues

* remove unused comments
parent 4c2bcb94
...@@ -9,7 +9,7 @@ from .common import MapDataComponent, MapData ...@@ -9,7 +9,7 @@ from .common import MapDataComponent, MapData
from ..utils import logger from ..utils import logger
from ..utils.argtools import shape2d from ..utils.argtools import shape2d
__all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageComponents'] __all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageCoordinates', 'AugmentImageComponents']
class ImageFromFile(RNGDataFlow): class ImageFromFile(RNGDataFlow):
...@@ -91,6 +91,53 @@ class AugmentImageComponent(MapDataComponent): ...@@ -91,6 +91,53 @@ class AugmentImageComponent(MapDataComponent):
self.augs.reset_state() self.augs.reset_state()
class AugmentImageCoordinates(MapData):
    """
    Apply image augmentors on an image and set of coordinates.

    The image is augmented first; the parameters sampled for the image are
    then reused to transform the coordinate component, so both components
    go through the same sequence of augmentations.
    """

    def __init__(self, ds, augmentors, img_index=0, coords_index=1, copy=True):
        """
        Args:
            ds (DataFlow): input DataFlow.
            augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` to be applied in order.
                Anything that is not already an ``AugmentorList`` is wrapped in one.
            img_index (int): the index of the image component to be augmented.
            coords_index (int): the index of the coordinate component to be augmented.
            copy (bool): Some augmentors modify the input images. When copy is
                True, a copy will be made before any augmentors are applied,
                to keep the original images not modified.
                Turn it off to save time when you know it's OK.
        """
        already_list = isinstance(augmentors, AugmentorList)
        self.augs = augmentors if already_list else AugmentorList(augmentors)
        # Number of datapoints whose augmentation raised; drives log throttling.
        self._nr_error = 0

        def func(dp):
            try:
                image = dp[img_index]
                points = dp[coords_index]
                if copy:
                    # Copy both components in one call so that any object
                    # shared between them stays shared in the copies.
                    image, points = copy_mod.deepcopy((image, points))
                image, params = self.augs._augment_return_params(image)
                dp[img_index] = image
                # Reuse the image's sampled parameters for the coordinates.
                dp[coords_index] = self.augs._augment_coords(points, params)
                return dp
            except KeyboardInterrupt:
                raise
            except Exception:
                self._nr_error += 1
                # Log the first few failures, then only every 1000th one.
                if self._nr_error < 10 or self._nr_error % 1000 == 0:
                    logger.exception("Got {} augmentation errors.".format(self._nr_error))
                # NOTE(review): presumably MapData drops datapoints mapped to
                # None -- confirm against MapData.
                return None

        super(AugmentImageCoordinates, self).__init__(ds, func)

    def reset_state(self):
        """Reset the state of the wrapped DataFlow and of the augmentors."""
        self.ds.reset_state()
        self.augs.reset_state()
class AugmentImageComponents(MapData): class AugmentImageComponents(MapData):
""" """
Apply image augmentors on several components, with shared augmentation parameters. Apply image augmentors on several components, with shared augmentation parameters.
...@@ -146,4 +193,5 @@ except ImportError: ...@@ -146,4 +193,5 @@ except ImportError:
from ..utils.develop import create_dummy_class from ..utils.develop import create_dummy_class
ImageFromFile = create_dummy_class('ImageFromFile', 'cv2') # noqa ImageFromFile = create_dummy_class('ImageFromFile', 'cv2') # noqa
AugmentImageComponent = create_dummy_class('AugmentImageComponent', 'cv2') # noqa AugmentImageComponent = create_dummy_class('AugmentImageComponent', 'cv2') # noqa
AugmentImageCoordinates = create_dummy_class('AugmentImageCoordinates', 'cv2') # noqa
AugmentImageComponents = create_dummy_class('AugmentImageComponents', 'cv2') # noqa AugmentImageComponents = create_dummy_class('AugmentImageComponents', 'cv2') # noqa
...@@ -98,6 +98,11 @@ class AugmentorList(ImageAugmentor): ...@@ -98,6 +98,11 @@ class AugmentorList(ImageAugmentor):
img = aug._augment(img, prm) img = aug._augment(img, prm)
return img return img
def _augment_coords(self, coords, param):
    """
    Run the coordinates through every augmentor in ``self.augs``, in order.

    Args:
        coords: the coordinates to transform.
        param: one parameter object per augmentor, paired positionally
            with ``self.augs``.

    Returns:
        The coordinates returned by the last augmentor (or ``coords``
        itself when there is nothing to apply).
    """
    result = coords
    for augmentor, aug_param in zip(self.augs, param):
        result = augmentor._augment_coords(result, aug_param)
    return result
def reset_state(self): def reset_state(self):
""" Will reset state of each augmentor """ """ Will reset state of each augmentor """
for a in self.augs: for a in self.augs:
......
...@@ -40,7 +40,10 @@ class RandomCrop(ImageAugmentor): ...@@ -40,7 +40,10 @@ class RandomCrop(ImageAugmentor):
return img[h0:h0 + self.crop_shape[0], w0:w0 + self.crop_shape[1]] return img[h0:h0 + self.crop_shape[0], w0:w0 + self.crop_shape[1]]
def _augment_coords(self, coords, param): def _augment_coords(self, coords, param):
raise NotImplementedError() h0, w0 = param
coords[:, 0] = coords[:, 0] - w0
coords[:, 1] = coords[:, 1] - h0
return coords
class CenterCrop(ImageAugmentor): class CenterCrop(ImageAugmentor):
...@@ -54,14 +57,21 @@ class CenterCrop(ImageAugmentor): ...@@ -54,14 +57,21 @@ class CenterCrop(ImageAugmentor):
crop_shape = shape2d(crop_shape) crop_shape = shape2d(crop_shape)
self._init(locals()) self._init(locals())
def _augment(self, img, _): def _get_augment_params(self, img):
orig_shape = img.shape orig_shape = img.shape
h0 = int((orig_shape[0] - self.crop_shape[0]) * 0.5) h0 = int((orig_shape[0] - self.crop_shape[0]) * 0.5)
w0 = int((orig_shape[1] - self.crop_shape[1]) * 0.5) w0 = int((orig_shape[1] - self.crop_shape[1]) * 0.5)
return (h0, w0)
def _augment(self, img, param):
h0, w0 = param
return img[h0:h0 + self.crop_shape[0], w0:w0 + self.crop_shape[1]] return img[h0:h0 + self.crop_shape[0], w0:w0 + self.crop_shape[1]]
def _augment_coords(self, coords, param): def _augment_coords(self, coords, param):
raise NotImplementedError() h0, w0 = param
coords[:, 0] = coords[:, 0] - w0
coords[:, 1] = coords[:, 1] - h0
return coords
def perturb_BB(image_shape, bb, max_perturb_pixel, def perturb_BB(image_shape, bb, max_perturb_pixel,
...@@ -127,8 +137,10 @@ class RandomCropAroundBox(ImageAugmentor): ...@@ -127,8 +137,10 @@ class RandomCropAroundBox(ImageAugmentor):
def _augment(self, img, newbox): def _augment(self, img, newbox):
return newbox.roi(img) return newbox.roi(img)
def _augment_coords(self, coords, param): def _augment_coords(self, coords, newbox):
raise NotImplementedError() coords[:, 0] = coords[:, 0] - newbox.x0
coords[:, 1] = coords[:, 1] - newbox.y0
return coords
class RandomCropRandomShape(ImageAugmentor): class RandomCropRandomShape(ImageAugmentor):
...@@ -165,6 +177,12 @@ class RandomCropRandomShape(ImageAugmentor): ...@@ -165,6 +177,12 @@ class RandomCropRandomShape(ImageAugmentor):
y0, x0, h, w = param y0, x0, h, w = param
return img[y0:y0 + h, x0:x0 + w] return img[y0:y0 + h, x0:x0 + w]
def _augment_coords(self, coords, param):
    """
    Shift coordinates into the frame of the cropped image.

    Args:
        coords: 2-D indexable array; column 0 is offset by ``x0`` and
            column 1 by ``y0``. It is modified in place.
        param: the ``(y0, x0, h, w)`` tuple describing the crop; only the
            two offsets are used here.

    Returns:
        The same ``coords`` object, after the shift.
    """
    top, left, _unused_h, _unused_w = param
    coords[:, 0] = coords[:, 0] - left
    coords[:, 1] = coords[:, 1] - top
    return coords
if __name__ == '__main__': if __name__ == '__main__':
print(perturb_BB([100, 100], Rect(3, 3, 50, 50), 50)) print(perturb_BB([100, 100], Rect(3, 3, 50, 50), 50))
...@@ -42,6 +42,9 @@ class Shift(ImageAugmentor): ...@@ -42,6 +42,9 @@ class Shift(ImageAugmentor):
ret = ret[:, :, np.newaxis] ret = ret[:, :, np.newaxis]
return ret return ret
def _augment_coords(self, coords, param):
    """Coordinate augmentation is not implemented for this augmentor.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
class Rotation(ImageAugmentor): class Rotation(ImageAugmentor):
""" Random rotate the image w.r.t a random center""" """ Random rotate the image w.r.t a random center"""
...@@ -79,6 +82,9 @@ class Rotation(ImageAugmentor): ...@@ -79,6 +82,9 @@ class Rotation(ImageAugmentor):
ret = ret[:, :, np.newaxis] ret = ret[:, :, np.newaxis]
return ret return ret
def _augment_coords(self, coords, param):
    """Coordinate augmentation is not implemented for this augmentor.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
class RotationAndCropValid(ImageAugmentor): class RotationAndCropValid(ImageAugmentor):
""" Random rotate and then crop the largest possible rectangle. """ Random rotate and then crop the largest possible rectangle.
...@@ -115,6 +121,9 @@ class RotationAndCropValid(ImageAugmentor): ...@@ -115,6 +121,9 @@ class RotationAndCropValid(ImageAugmentor):
# print(ret.shape, deg, newx, newy, neww, newh) # print(ret.shape, deg, newx, newy, neww, newh)
return ret[newy:newy + newh, newx:newx + neww] return ret[newy:newy + newh, newx:newx + neww]
def _augment_coords(self, coords, param):
    """Coordinate augmentation is not implemented for this augmentor.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
@staticmethod @staticmethod
def largest_rotated_rect(w, h, angle): def largest_rotated_rect(w, h, angle):
""" """
......
...@@ -56,6 +56,12 @@ class RandomApplyAug(ImageAugmentor): ...@@ -56,6 +56,12 @@ class RandomApplyAug(ImageAugmentor):
else: else:
return self.aug._augment(img, prm[1]) return self.aug._augment(img, prm[1])
def _augment_coords(self, coords, prm):
    """
    Transform the coordinates only if the wrapped augmentor was applied.

    Args:
        coords: the coordinates to transform.
        prm: indexable pair; ``prm[0]`` is the "was applied" flag and
            ``prm[1]`` is the parameter for the wrapped augmentor.

    Returns:
        ``coords`` unchanged when the flag is falsy, otherwise the result
        of ``self.aug._augment_coords``.
    """
    if prm[0]:
        return self.aug._augment_coords(coords, prm[1])
    return coords
class RandomChooseAug(ImageAugmentor): class RandomChooseAug(ImageAugmentor):
""" Randomly choose one from a list of augmentors """ """ Randomly choose one from a list of augmentors """
...@@ -69,7 +75,7 @@ class RandomChooseAug(ImageAugmentor): ...@@ -69,7 +75,7 @@ class RandomChooseAug(ImageAugmentor):
aug_lists = [k[0] for k in aug_lists] aug_lists = [k[0] for k in aug_lists]
self._init(locals()) self._init(locals())
else: else:
prob = 1.0 / len(aug_lists) prob = [1.0 / len(aug_lists)] * len(aug_lists)
self._init(locals()) self._init(locals())
super(RandomChooseAug, self).__init__() super(RandomChooseAug, self).__init__()
...@@ -87,6 +93,10 @@ class RandomChooseAug(ImageAugmentor): ...@@ -87,6 +93,10 @@ class RandomChooseAug(ImageAugmentor):
idx, prm = prm idx, prm = prm
return self.aug_lists[idx]._augment(img, prm) return self.aug_lists[idx]._augment(img, prm)
def _augment_coords(self, coords, prm):
    """
    Transform the coordinates with the augmentor that was chosen.

    Args:
        coords: the coordinates to transform.
        prm: an ``(index, inner_param)`` pair; ``index`` selects the
            augmentor from ``self.aug_lists`` and ``inner_param`` is
            passed to it.

    Returns:
        Whatever the selected augmentor's ``_augment_coords`` returns.
    """
    chosen, inner_prm = prm
    augmentor = self.aug_lists[chosen]
    return augmentor._augment_coords(coords, inner_prm)
class RandomOrderAug(ImageAugmentor): class RandomOrderAug(ImageAugmentor):
""" """
...@@ -121,18 +131,30 @@ class RandomOrderAug(ImageAugmentor): ...@@ -121,18 +131,30 @@ class RandomOrderAug(ImageAugmentor):
img = self.aug_lists[k]._augment(img, prms[k]) img = self.aug_lists[k]._augment(img, prms[k])
return img return img
def _augment_coords(self, coords, prm):
    """
    Transform the coordinates with the augmentors in the sampled order.

    Args:
        coords: the coordinates to transform.
        prm: an ``(idxs, prms)`` pair; ``idxs`` is the order in which the
            augmentors in ``self.aug_lists`` are applied and ``prms[k]``
            is the parameter for augmentor ``k``.

    Returns:
        The coordinates after all selected augmentors have been applied
        one after another; ``coords`` itself when ``idxs`` is empty.
    """
    idxs, prms = prm
    for k in idxs:
        # Feed each result into the next augmentor so the transforms
        # compose, the same way the image is chained in ``_augment``.
        coords = self.aug_lists[k]._augment_coords(coords, prms[k])
    return coords
class MapImage(ImageAugmentor): class MapImage(ImageAugmentor):
""" """
Map the image array by a function. Map the image array by a function.
""" """
def __init__(self, func): def __init__(self, func, coord_func=None):
""" """
Args: Args:
func: a function which takes an image array and return an augmented one func: a function which takes an image array and return an augmented one
""" """
self.func = func self.func = func
self.coord_func = coord_func
def _augment(self, img, _): def _augment(self, img, _):
return self.func(img) return self.func(img)
def _augment_coords(self, coords, _):
    """
    Map the coordinates with the user-supplied ``coord_func``.

    Args:
        coords: the coordinates to transform.
        _: unused augmentation parameter.

    Returns:
        The value returned by ``self.coord_func(coords)``.

    Raises:
        NotImplementedError: if no ``coord_func`` was provided.
    """
    mapper = self.coord_func
    if mapper is None:
        raise NotImplementedError
    return mapper(coords)
...@@ -35,9 +35,12 @@ class Flip(ImageAugmentor): ...@@ -35,9 +35,12 @@ class Flip(ImageAugmentor):
self._init() self._init()
def _get_augment_params(self, img): def _get_augment_params(self, img):
return self._rand_range() < self.prob h, w = img.shape[:2]
do = self._rand_range() < self.prob
return (do, h, w)
def _augment(self, img, do): def _augment(self, img, param):
do, _, _ = param
if do: if do:
ret = cv2.flip(img, self.code) ret = cv2.flip(img, self.code)
if img.ndim == 3 and ret.ndim == 2: if img.ndim == 3 and ret.ndim == 2:
...@@ -47,7 +50,13 @@ class Flip(ImageAugmentor): ...@@ -47,7 +50,13 @@ class Flip(ImageAugmentor):
return ret return ret
def _augment_coords(self, coords, param): def _augment_coords(self, coords, param):
raise NotImplementedError() do, h, w = param
if do:
if self.code == 0:
coords[:, 1] = h - coords[:, 1]
elif self.code == 1:
coords[:, 0] = w - coords[:, 0]
return coords
class Resize(ImageAugmentor): class Resize(ImageAugmentor):
...@@ -62,6 +71,10 @@ class Resize(ImageAugmentor): ...@@ -62,6 +71,10 @@ class Resize(ImageAugmentor):
shape = tuple(shape2d(shape)) shape = tuple(shape2d(shape))
self._init(locals()) self._init(locals())
def _get_augment_params(self, img):
    """
    Record the size of the input image before it is resized.

    Args:
        img: an array whose first two dimensions are height and width.

    Returns:
        The ``(height, width)`` pair taken from ``img.shape[:2]``.
    """
    height, width = img.shape[:2]
    return (height, width)
def _augment(self, img, _): def _augment(self, img, _):
ret = cv2.resize( ret = cv2.resize(
img, self.shape[::-1], img, self.shape[::-1],
...@@ -70,6 +83,12 @@ class Resize(ImageAugmentor): ...@@ -70,6 +83,12 @@ class Resize(ImageAugmentor):
ret = ret[:, :, np.newaxis] ret = ret[:, :, np.newaxis]
return ret return ret
def _augment_coords(self, coords, param):
    """
    Rescale coordinates from the original image size to ``self.shape``.

    Args:
        coords: 2-D indexable array; column 0 is scaled by the width
            ratio and column 1 by the height ratio. Modified in place.
        param: the ``(h, w)`` size of the image before resizing.

    Returns:
        The same ``coords`` object, after rescaling.
    """
    src_h, src_w = param
    coords[:, 0] = coords[:, 0] * self.shape[1] * 1.0 / src_w
    coords[:, 1] = coords[:, 1] * self.shape[0] * 1.0 / src_h
    return coords
class ResizeShortestEdge(ImageAugmentor): class ResizeShortestEdge(ImageAugmentor):
""" """
...@@ -85,15 +104,25 @@ class ResizeShortestEdge(ImageAugmentor): ...@@ -85,15 +104,25 @@ class ResizeShortestEdge(ImageAugmentor):
size = size * 1.0 size = size * 1.0
self._init(locals()) self._init(locals())
def _augment(self, img, _): def _get_augment_params(self, img):
h, w = img.shape[:2] h, w = img.shape[:2]
scale = self.size / min(h, w) scale = self.size / min(h, w)
desSize = map(int, [scale * w, scale * h]) newh, neww = map(int, [scale * h, scale * w])
ret = cv2.resize(img, tuple(desSize), interpolation=self.interp) return (h, w, newh, neww)
def _augment(self, img, param):
_, _, newh, neww = param
ret = cv2.resize(img, (neww, newh), interpolation=self.interp)
if img.ndim == 3 and ret.ndim == 2: if img.ndim == 3 and ret.ndim == 2:
ret = ret[:, :, np.newaxis] ret = ret[:, :, np.newaxis]
return ret return ret
def _augment_coords(self, coords, param):
    """
    Rescale coordinates by the same factors applied to the image.

    Args:
        coords: 2-D indexable array; column 0 is scaled by ``neww / w``
            and column 1 by ``newh / h``. Modified in place.
        param: the ``(h, w, newh, neww)`` tuple holding the image size
            before and after resizing.

    Returns:
        The same ``coords`` object, after rescaling.
    """
    src_h, src_w, dst_h, dst_w = param
    coords[:, 0] = coords[:, 0] * dst_w * 1.0 / src_w
    coords[:, 1] = coords[:, 1] * dst_h * 1.0 / src_h
    return coords
class RandomResize(ImageAugmentor): class RandomResize(ImageAugmentor):
""" Randomly rescale w and h of the image""" """ Randomly rescale w and h of the image"""
...@@ -117,30 +146,38 @@ class RandomResize(ImageAugmentor): ...@@ -117,30 +146,38 @@ class RandomResize(ImageAugmentor):
def _get_augment_params(self, img): def _get_augment_params(self, img):
cnt = 0 cnt = 0
h, w = img.shape[:2]
while True: while True:
sx = self._rand_range(*self.xrange) sx = self._rand_range(*self.xrange)
if self.aspect_ratio_thres == 0: if self.aspect_ratio_thres == 0:
sy = sx sy = sx
else: else:
sy = self._rand_range(*self.yrange) sy = self._rand_range(*self.yrange)
destX = max(sx * img.shape[1], self.minimum[0]) destX = max(sx * w, self.minimum[0])
destY = max(sy * img.shape[0], self.minimum[1]) destY = max(sy * h, self.minimum[1])
oldr = img.shape[1] * 1.0 / img.shape[0] oldr = w * 1.0 / h
newr = destX * 1.0 / destY newr = destX * 1.0 / destY
diff = abs(newr - oldr) / oldr diff = abs(newr - oldr) / oldr
if diff <= self.aspect_ratio_thres + 1e-5: if diff <= self.aspect_ratio_thres + 1e-5:
return (int(destX), int(destY)) return (h, w, int(destY), int(destX))
cnt += 1 cnt += 1
if cnt > 50: if cnt > 50:
logger.warn("RandomResize failed to augment an image") logger.warn("RandomResize failed to augment an image")
return img.shape[1], img.shape[0] return (h, w, h, w)
def _augment(self, img, dsize): def _augment(self, img, param):
ret = cv2.resize(img, dsize, interpolation=self.interp) _, _, newh, neww = param
ret = cv2.resize(img, (neww, newh), interpolation=self.interp)
if img.ndim == 3 and ret.ndim == 2: if img.ndim == 3 and ret.ndim == 2:
ret = ret[:, :, np.newaxis] ret = ret[:, :, np.newaxis]
return ret return ret
def _augment_coords(self, coords, param):
    """
    Rescale coordinates by the same factors applied to the image.

    Args:
        coords: 2-D indexable array; column 0 is scaled by ``neww / w``
            and column 1 by ``newh / h``. Modified in place.
        param: the ``(h, w, newh, neww)`` tuple holding the image size
            before and after resizing.

    Returns:
        The same ``coords`` object, after rescaling.
    """
    src_h, src_w, dst_h, dst_w = param
    coords[:, 0] = coords[:, 0] * dst_w * 1.0 / src_w
    coords[:, 1] = coords[:, 1] * dst_h * 1.0 / src_h
    return coords
class Transpose(ImageAugmentor): class Transpose(ImageAugmentor):
""" """
...@@ -166,5 +203,7 @@ class Transpose(ImageAugmentor): ...@@ -166,5 +203,7 @@ class Transpose(ImageAugmentor):
ret = ret[:, :, np.newaxis] ret = ret[:, :, np.newaxis]
return ret return ret
def _augment_coords(self, coords, param): def _augment_coords(self, coords, do):
raise NotImplementedError() if do:
coords = coords[:, ::-1]
return coords
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment