Commit 42f2c644 authored by Yuxin Wu's avatar Yuxin Wu

share the implementation of crop/resize augmentors.

parent 77cf6145
...@@ -2,17 +2,20 @@ ...@@ -2,17 +2,20 @@
# File: crop.py # File: crop.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
from six.moves import range
import numpy as np
from .base import ImageAugmentor from .base import ImageAugmentor
from ...utils.rect import IntBox from ...utils.rect import IntBox
from ...utils.develop import log_deprecated
from ...utils.argtools import shape2d from ...utils.argtools import shape2d
from .transform import TransformAugmentorBase, CropTransform
from six.moves import range
import numpy as np
__all__ = ['RandomCrop', 'CenterCrop', 'RandomCropAroundBox', 'RandomCropRandomShape'] __all__ = ['RandomCrop', 'CenterCrop', 'RandomCropAroundBox', 'RandomCropRandomShape']
class RandomCrop(ImageAugmentor): class RandomCrop(TransformAugmentorBase):
""" Randomly crop the image into a smaller one """ """ Randomly crop the image into a smaller one """
def __init__(self, crop_shape): def __init__(self, crop_shape):
...@@ -32,20 +35,10 @@ class RandomCrop(ImageAugmentor): ...@@ -32,20 +35,10 @@ class RandomCrop(ImageAugmentor):
h0 = 0 if diffh == 0 else self.rng.randint(diffh) h0 = 0 if diffh == 0 else self.rng.randint(diffh)
diffw = orig_shape[1] - self.crop_shape[1] diffw = orig_shape[1] - self.crop_shape[1]
w0 = 0 if diffw == 0 else self.rng.randint(diffw) w0 = 0 if diffw == 0 else self.rng.randint(diffw)
return (h0, w0) return CropTransform(h0, w0, self.crop_shape[0], self.crop_shape[1])
def _augment(self, img, param):
h0, w0 = param
return img[h0:h0 + self.crop_shape[0], w0:w0 + self.crop_shape[1]]
def _augment_coords(self, coords, param):
h0, w0 = param
coords[:, 0] = coords[:, 0] - w0
coords[:, 1] = coords[:, 1] - h0
return coords
class CenterCrop(ImageAugmentor): class CenterCrop(TransformAugmentorBase):
""" Crop the image at the center""" """ Crop the image at the center"""
def __init__(self, crop_shape): def __init__(self, crop_shape):
...@@ -60,17 +53,7 @@ class CenterCrop(ImageAugmentor): ...@@ -60,17 +53,7 @@ class CenterCrop(ImageAugmentor):
orig_shape = img.shape orig_shape = img.shape
h0 = int((orig_shape[0] - self.crop_shape[0]) * 0.5) h0 = int((orig_shape[0] - self.crop_shape[0]) * 0.5)
w0 = int((orig_shape[1] - self.crop_shape[1]) * 0.5) w0 = int((orig_shape[1] - self.crop_shape[1]) * 0.5)
return (h0, w0) return CropTransform(h0, w0, self.crop_shape[0], self.crop_shape[1])
def _augment(self, img, param):
h0, w0 = param
return img[h0:h0 + self.crop_shape[0], w0:w0 + self.crop_shape[1]]
def _augment_coords(self, coords, param):
h0, w0 = param
coords[:, 0] = coords[:, 0] - w0
coords[:, 1] = coords[:, 1] - h0
return coords
def perturb_BB(image_shape, bb, max_perturb_pixel, def perturb_BB(image_shape, bb, max_perturb_pixel,
...@@ -108,7 +91,7 @@ def perturb_BB(image_shape, bb, max_perturb_pixel, ...@@ -108,7 +91,7 @@ def perturb_BB(image_shape, bb, max_perturb_pixel,
return bb return bb
# TODO shouldn't include strange augmentors like this. # TODO deprecated. shouldn't include strange augmentors like this.
class RandomCropAroundBox(ImageAugmentor): class RandomCropAroundBox(ImageAugmentor):
""" """
Crop a box around a bounding box by some random perturbation. Crop a box around a bounding box by some random perturbation.
...@@ -122,6 +105,10 @@ class RandomCropAroundBox(ImageAugmentor): ...@@ -122,6 +105,10 @@ class RandomCropAroundBox(ImageAugmentor):
max_aspect_ratio_diff (float): keep aspect ratio difference within the range max_aspect_ratio_diff (float): keep aspect ratio difference within the range
""" """
super(RandomCropAroundBox, self).__init__() super(RandomCropAroundBox, self).__init__()
log_deprecated(
"RandomCropAroundBox",
"It's neither common nor well-defined. Please implement something by yourself.",
"2017-11-30")
self._init(locals()) self._init(locals())
def _get_augment_params(self, img): def _get_augment_params(self, img):
...@@ -141,7 +128,7 @@ class RandomCropAroundBox(ImageAugmentor): ...@@ -141,7 +128,7 @@ class RandomCropAroundBox(ImageAugmentor):
return coords return coords
class RandomCropRandomShape(ImageAugmentor): class RandomCropRandomShape(TransformAugmentorBase):
""" Random crop with a random shape""" """ Random crop with a random shape"""
def __init__(self, wmin, hmin, def __init__(self, wmin, hmin,
...@@ -169,17 +156,7 @@ class RandomCropRandomShape(ImageAugmentor): ...@@ -169,17 +156,7 @@ class RandomCropRandomShape(ImageAugmentor):
assert diffh >= 0 and diffw >= 0 assert diffh >= 0 and diffw >= 0
y0 = 0 if diffh == 0 else self.rng.randint(diffh) y0 = 0 if diffh == 0 else self.rng.randint(diffh)
x0 = 0 if diffw == 0 else self.rng.randint(diffw) x0 = 0 if diffw == 0 else self.rng.randint(diffw)
return (y0, x0, h, w) return CropTransform(y0, x0, h, w)
def _augment(self, img, param):
y0, x0, h, w = param
return img[y0:y0 + h, x0:x0 + w]
def _augment_coords(self, coords, param):
y0, x0, _, _ = param
coords[:, 0] = coords[:, 0] - x0
coords[:, 1] = coords[:, 1] - y0
return coords
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -2,11 +2,13 @@ ...@@ -2,11 +2,13 @@
# File: misc.py # File: misc.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np
import cv2
from .base import ImageAugmentor from .base import ImageAugmentor
from ...utils import logger from ...utils import logger
from ...utils.argtools import shape2d from ...utils.argtools import shape2d
import numpy as np from .transform import ResizeTransform, TransformAugmentorBase
import cv2
__all__ = ['Flip', 'Resize', 'RandomResize', 'ResizeShortestEdge', 'Transpose'] __all__ = ['Flip', 'Resize', 'RandomResize', 'ResizeShortestEdge', 'Transpose']
...@@ -59,7 +61,7 @@ class Flip(ImageAugmentor): ...@@ -59,7 +61,7 @@ class Flip(ImageAugmentor):
return coords return coords
class Resize(ImageAugmentor): class Resize(TransformAugmentorBase):
""" Resize image to a target size""" """ Resize image to a target size"""
def __init__(self, shape, interp=cv2.INTER_LINEAR): def __init__(self, shape, interp=cv2.INTER_LINEAR):
...@@ -72,25 +74,12 @@ class Resize(ImageAugmentor): ...@@ -72,25 +74,12 @@ class Resize(ImageAugmentor):
self._init(locals()) self._init(locals())
def _get_augment_params(self, img): def _get_augment_params(self, img):
h, w = img.shape[:2] return ResizeTransform(
return (h, w) img.shape[0], img.shape[1],
self.shape[0], self.shape[1], self.interp)
def _augment(self, img, _):
ret = cv2.resize(
img, self.shape[::-1],
interpolation=self.interp)
if img.ndim == 3 and ret.ndim == 2:
ret = ret[:, :, np.newaxis]
return ret
def _augment_coords(self, coords, param):
h, w = param
coords[:, 0] = coords[:, 0] * (self.shape[1] * 1.0 / w)
coords[:, 1] = coords[:, 1] * (self.shape[0] * 1.0 / h)
return coords
class ResizeShortestEdge(TransformAugmentorBase):
class ResizeShortestEdge(ImageAugmentor):
""" """
Resize the shortest edge to a certain number while Resize the shortest edge to a certain number while
keeping the aspect ratio. keeping the aspect ratio.
...@@ -111,23 +100,11 @@ class ResizeShortestEdge(ImageAugmentor): ...@@ -111,23 +100,11 @@ class ResizeShortestEdge(ImageAugmentor):
newh, neww = self.size, int(scale * w) newh, neww = self.size, int(scale * w)
else: else:
newh, neww = int(scale * h), self.size newh, neww = int(scale * h), self.size
return (h, w, newh, neww) return ResizeTransform(
h, w, newh, neww, self.interp)
def _augment(self, img, param):
_, _, newh, neww = param
ret = cv2.resize(img, (neww, newh), interpolation=self.interp)
if img.ndim == 3 and ret.ndim == 2:
ret = ret[:, :, np.newaxis]
return ret
def _augment_coords(self, coords, param):
h, w, newh, neww = param
coords[:, 0] = coords[:, 0] * (neww * 1.0 / w)
coords[:, 1] = coords[:, 1] * (newh * 1.0 / h)
return coords
class RandomResize(ImageAugmentor): class RandomResize(TransformAugmentorBase):
""" Randomly rescale width and height of the image.""" """ Randomly rescale width and height of the image."""
def __init__(self, xrange, yrange, minimum=(0, 0), aspect_ratio_thres=0.15, def __init__(self, xrange, yrange, minimum=(0, 0), aspect_ratio_thres=0.15,
...@@ -187,22 +164,9 @@ class RandomResize(ImageAugmentor): ...@@ -187,22 +164,9 @@ class RandomResize(ImageAugmentor):
cnt += 1 cnt += 1
if cnt > 50: if cnt > 50:
logger.warn("RandomResize failed to augment an image") logger.warn("RandomResize failed to augment an image")
return (h, w, h, w) return ResizeTransform(h, w, h, w, self.interp)
continue continue
return (h, w, int(destY), int(destX)) return ResizeTransform(h, w, int(destY), int(destX), self.interp)
def _augment(self, img, param):
_, _, newh, neww = param
ret = cv2.resize(img, (neww, newh), interpolation=self.interp)
if img.ndim == 3 and ret.ndim == 2:
ret = ret[:, :, np.newaxis]
return ret
def _augment_coords(self, coords, param):
h, w, newh, neww = param
coords[:, 0] = coords[:, 0] * (neww * 1.0 / w)
coords[:, 1] = coords[:, 1] * (newh * 1.0 / h)
return coords
class Transpose(ImageAugmentor): class Transpose(ImageAugmentor):
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: transform.py
from abc import abstractmethod, ABCMeta
import six
import cv2
import numpy as np
from .base import ImageAugmentor
__all__ = []
class TransformAugmentorBase(ImageAugmentor):
"""
Base class of augmentors which use :class:`ImageTransform`
for the actual implementation of the transformations.
It assumes that :meth:`_get_augment_params` should
return a :class:`ImageTransform` instance, and it will use
this instance to augment both image and coordinates.
"""
def _augment(self, img, t):
return t.apply_image(img)
def _augment_coords(self, coords, t):
return t.apply_coords(coords)
@six.add_metaclass(ABCMeta)
class ImageTransform(object):
"""
A deterministic image transformation, used to implement
the (probably random) augmentors.
This way the deterministic part
(the actual transformation which may be common between augmentors)
can be separated from the random part
(the random policy which is different between augmentors).
"""
def _init(self, params=None):
if params:
for k, v in params.items():
if k != 'self':
setattr(self, k, v)
@abstractmethod
def apply_image(self, img):
pass
@abstractmethod
def apply_coords(self, coords):
pass
class ResizeTransform(ImageTransform):
def __init__(self, h, w, newh, neww, interp):
self._init(locals())
def apply_image(self, img):
assert img.shape[:2] == (self.h, self.w)
ret = cv2.resize(
img, (self.neww, self.newh),
interpolation=self.interp)
if img.ndim == 3 and ret.ndim == 2:
ret = ret[:, :, np.newaxis]
return ret
def apply_coords(self, coords):
coords[:, 0] = coords[:, 0] * (self.neww * 1.0 / self.w)
coords[:, 1] = coords[:, 1] * (self.newh * 1.0 / self.h)
return coords
class CropTransform(ImageTransform):
def __init__(self, h0, w0, h, w):
self._init(locals())
def apply_image(self, img):
return img[self.h0:self.h0 + self.h, self.w0:self.w0 + self.w]
def apply_coords(self, coords):
coords[:, 0] -= self.w0
coords[:, 1] -= self.h0
return coords
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment