Commit 42f2c644 authored by Yuxin Wu's avatar Yuxin Wu

share the implementation of crop/resize augmentors.

parent 77cf6145
......@@ -2,17 +2,20 @@
# File: crop.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from six.moves import range
import numpy as np
from .base import ImageAugmentor
from ...utils.rect import IntBox
from ...utils.develop import log_deprecated
from ...utils.argtools import shape2d
from .transform import TransformAugmentorBase, CropTransform
from six.moves import range
import numpy as np
__all__ = ['RandomCrop', 'CenterCrop', 'RandomCropAroundBox', 'RandomCropRandomShape']
class RandomCrop(ImageAugmentor):
class RandomCrop(TransformAugmentorBase):
""" Randomly crop the image into a smaller one """
def __init__(self, crop_shape):
......@@ -32,20 +35,10 @@ class RandomCrop(ImageAugmentor):
h0 = 0 if diffh == 0 else self.rng.randint(diffh)
diffw = orig_shape[1] - self.crop_shape[1]
w0 = 0 if diffw == 0 else self.rng.randint(diffw)
return (h0, w0)
def _augment(self, img, param):
h0, w0 = param
return img[h0:h0 + self.crop_shape[0], w0:w0 + self.crop_shape[1]]
def _augment_coords(self, coords, param):
h0, w0 = param
coords[:, 0] = coords[:, 0] - w0
coords[:, 1] = coords[:, 1] - h0
return coords
return CropTransform(h0, w0, self.crop_shape[0], self.crop_shape[1])
class CenterCrop(ImageAugmentor):
class CenterCrop(TransformAugmentorBase):
""" Crop the image at the center"""
def __init__(self, crop_shape):
......@@ -60,17 +53,7 @@ class CenterCrop(ImageAugmentor):
orig_shape = img.shape
h0 = int((orig_shape[0] - self.crop_shape[0]) * 0.5)
w0 = int((orig_shape[1] - self.crop_shape[1]) * 0.5)
return (h0, w0)
def _augment(self, img, param):
h0, w0 = param
return img[h0:h0 + self.crop_shape[0], w0:w0 + self.crop_shape[1]]
def _augment_coords(self, coords, param):
h0, w0 = param
coords[:, 0] = coords[:, 0] - w0
coords[:, 1] = coords[:, 1] - h0
return coords
return CropTransform(h0, w0, self.crop_shape[0], self.crop_shape[1])
def perturb_BB(image_shape, bb, max_perturb_pixel,
......@@ -108,7 +91,7 @@ def perturb_BB(image_shape, bb, max_perturb_pixel,
return bb
# TODO shouldn't include strange augmentors like this.
# TODO deprecated. shouldn't include strange augmentors like this.
class RandomCropAroundBox(ImageAugmentor):
"""
Crop a box around a bounding box by some random perturbation.
......@@ -122,6 +105,10 @@ class RandomCropAroundBox(ImageAugmentor):
max_aspect_ratio_diff (float): keep aspect ratio difference within the range
"""
super(RandomCropAroundBox, self).__init__()
log_deprecated(
"RandomCropAroundBox",
"It's neither common nor well-defined. Please implement something by yourself.",
"2017-11-30")
self._init(locals())
def _get_augment_params(self, img):
......@@ -141,7 +128,7 @@ class RandomCropAroundBox(ImageAugmentor):
return coords
class RandomCropRandomShape(ImageAugmentor):
class RandomCropRandomShape(TransformAugmentorBase):
""" Random crop with a random shape"""
def __init__(self, wmin, hmin,
......@@ -169,17 +156,7 @@ class RandomCropRandomShape(ImageAugmentor):
assert diffh >= 0 and diffw >= 0
y0 = 0 if diffh == 0 else self.rng.randint(diffh)
x0 = 0 if diffw == 0 else self.rng.randint(diffw)
return (y0, x0, h, w)
def _augment(self, img, param):
y0, x0, h, w = param
return img[y0:y0 + h, x0:x0 + w]
def _augment_coords(self, coords, param):
y0, x0, _, _ = param
coords[:, 0] = coords[:, 0] - x0
coords[:, 1] = coords[:, 1] - y0
return coords
return CropTransform(y0, x0, h, w)
if __name__ == '__main__':
......
......@@ -2,11 +2,13 @@
# File: misc.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np
import cv2
from .base import ImageAugmentor
from ...utils import logger
from ...utils.argtools import shape2d
import numpy as np
import cv2
from .transform import ResizeTransform, TransformAugmentorBase
__all__ = ['Flip', 'Resize', 'RandomResize', 'ResizeShortestEdge', 'Transpose']
......@@ -59,7 +61,7 @@ class Flip(ImageAugmentor):
return coords
class Resize(ImageAugmentor):
class Resize(TransformAugmentorBase):
""" Resize image to a target size"""
def __init__(self, shape, interp=cv2.INTER_LINEAR):
......@@ -72,25 +74,12 @@ class Resize(ImageAugmentor):
self._init(locals())
def _get_augment_params(self, img):
h, w = img.shape[:2]
return (h, w)
def _augment(self, img, _):
ret = cv2.resize(
img, self.shape[::-1],
interpolation=self.interp)
if img.ndim == 3 and ret.ndim == 2:
ret = ret[:, :, np.newaxis]
return ret
return ResizeTransform(
img.shape[0], img.shape[1],
self.shape[0], self.shape[1], self.interp)
def _augment_coords(self, coords, param):
h, w = param
coords[:, 0] = coords[:, 0] * (self.shape[1] * 1.0 / w)
coords[:, 1] = coords[:, 1] * (self.shape[0] * 1.0 / h)
return coords
class ResizeShortestEdge(ImageAugmentor):
class ResizeShortestEdge(TransformAugmentorBase):
"""
Resize the shortest edge to a certain number while
keeping the aspect ratio.
......@@ -111,23 +100,11 @@ class ResizeShortestEdge(ImageAugmentor):
newh, neww = self.size, int(scale * w)
else:
newh, neww = int(scale * h), self.size
return (h, w, newh, neww)
def _augment(self, img, param):
_, _, newh, neww = param
ret = cv2.resize(img, (neww, newh), interpolation=self.interp)
if img.ndim == 3 and ret.ndim == 2:
ret = ret[:, :, np.newaxis]
return ret
def _augment_coords(self, coords, param):
h, w, newh, neww = param
coords[:, 0] = coords[:, 0] * (neww * 1.0 / w)
coords[:, 1] = coords[:, 1] * (newh * 1.0 / h)
return coords
return ResizeTransform(
h, w, newh, neww, self.interp)
class RandomResize(ImageAugmentor):
class RandomResize(TransformAugmentorBase):
""" Randomly rescale width and height of the image."""
def __init__(self, xrange, yrange, minimum=(0, 0), aspect_ratio_thres=0.15,
......@@ -187,22 +164,9 @@ class RandomResize(ImageAugmentor):
cnt += 1
if cnt > 50:
logger.warn("RandomResize failed to augment an image")
return (h, w, h, w)
return ResizeTransform(h, w, h, w, self.interp)
continue
return (h, w, int(destY), int(destX))
def _augment(self, img, param):
_, _, newh, neww = param
ret = cv2.resize(img, (neww, newh), interpolation=self.interp)
if img.ndim == 3 and ret.ndim == 2:
ret = ret[:, :, np.newaxis]
return ret
def _augment_coords(self, coords, param):
h, w, newh, neww = param
coords[:, 0] = coords[:, 0] * (neww * 1.0 / w)
coords[:, 1] = coords[:, 1] * (newh * 1.0 / h)
return coords
return ResizeTransform(h, w, int(destY), int(destX), self.interp)
class Transpose(ImageAugmentor):
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: transform.py
from abc import abstractmethod, ABCMeta
import six
import cv2
import numpy as np
from .base import ImageAugmentor
__all__ = []
class TransformAugmentorBase(ImageAugmentor):
"""
Base class of augmentors which use :class:`ImageTransform`
for the actual implementation of the transformations.
It assumes that :meth:`_get_augment_params` should
return a :class:`ImageTransform` instance, and it will use
this instance to augment both image and coordinates.
"""
def _augment(self, img, t):
return t.apply_image(img)
def _augment_coords(self, coords, t):
return t.apply_coords(coords)
@six.add_metaclass(ABCMeta)
class ImageTransform(object):
"""
A deterministic image transformation, used to implement
the (probably random) augmentors.
This way the deterministic part
(the actual transformation which may be common between augmentors)
can be separated from the random part
(the random policy which is different between augmentors).
"""
def _init(self, params=None):
if params:
for k, v in params.items():
if k != 'self':
setattr(self, k, v)
@abstractmethod
def apply_image(self, img):
pass
@abstractmethod
def apply_coords(self, coords):
pass
class ResizeTransform(ImageTransform):
def __init__(self, h, w, newh, neww, interp):
self._init(locals())
def apply_image(self, img):
assert img.shape[:2] == (self.h, self.w)
ret = cv2.resize(
img, (self.neww, self.newh),
interpolation=self.interp)
if img.ndim == 3 and ret.ndim == 2:
ret = ret[:, :, np.newaxis]
return ret
def apply_coords(self, coords):
coords[:, 0] = coords[:, 0] * (self.neww * 1.0 / self.w)
coords[:, 1] = coords[:, 1] * (self.newh * 1.0 / self.h)
return coords
class CropTransform(ImageTransform):
def __init__(self, h0, w0, h, w):
self._init(locals())
def apply_image(self, img):
return img[self.h0:self.h0 + self.h, self.w0:self.w0 + self.w]
def apply_coords(self, coords):
coords[:, 0] -= self.w0
coords[:, 1] -= self.h0
return coords
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment