Commit fbeed06b authored by Yuxin Wu's avatar Yuxin Wu

Merge duplicated rotation implementation by using Transform

parent 16245e35
...@@ -3,15 +3,17 @@ ...@@ -3,15 +3,17 @@
# File: geometry.py # File: geometry.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from .base import ImageAugmentor
import math import math
import cv2 import cv2
import numpy as np import numpy as np
from .base import ImageAugmentor
from .transform import TransformAugmentorBase, WarpAffineTransform
__all__ = ['Shift', 'Rotation', 'RotationAndCropValid', 'Affine'] __all__ = ['Shift', 'Rotation', 'RotationAndCropValid', 'Affine']
class Shift(ImageAugmentor): class Shift(TransformAugmentorBase):
""" Random horizontal and vertical shifts """ """ Random horizontal and vertical shifts """
def __init__(self, horiz_frac=0, vert_frac=0, def __init__(self, horiz_frac=0, vert_frac=0,
...@@ -32,21 +34,12 @@ class Shift(ImageAugmentor): ...@@ -32,21 +34,12 @@ class Shift(ImageAugmentor):
max_dy = self.vert_frac * img.shape[0] max_dy = self.vert_frac * img.shape[0]
dx = np.round(self._rand_range(-max_dx, max_dx)) dx = np.round(self._rand_range(-max_dx, max_dx))
dy = np.round(self._rand_range(-max_dy, max_dy)) dy = np.round(self._rand_range(-max_dy, max_dy))
return np.float32(
[[1, 0, dx], [0, 1, dy]])
def _augment(self, img, shift_m):
ret = cv2.warpAffine(img, shift_m, img.shape[1::-1],
borderMode=self.border, borderValue=self.border_value)
if img.ndim == 3 and ret.ndim == 2:
ret = ret[:, :, np.newaxis]
return ret
def _augment_coords(self, coords, param): mat = np.array([[1, 0, dx], [0, 1, dy]], dtype='float32')
raise NotImplementedError() return WarpAffineTransform(mat, img.shape[1::-1], self.border, self.border_value)
class Rotation(ImageAugmentor): class Rotation(TransformAugmentorBase):
""" Random rotate the image w.r.t a random center""" """ Random rotate the image w.r.t a random center"""
def __init__(self, max_deg, center_range=(0, 1), def __init__(self, max_deg, center_range=(0, 1),
...@@ -73,17 +66,21 @@ class Rotation(ImageAugmentor): ...@@ -73,17 +66,21 @@ class Rotation(ImageAugmentor):
deg = self._rand_range(-self.max_deg, self.max_deg) deg = self._rand_range(-self.max_deg, self.max_deg)
if self.step_deg: if self.step_deg:
deg = deg // self.step_deg * self.step_deg deg = deg // self.step_deg * self.step_deg
return cv2.getRotationMatrix2D(tuple(center - 0.5), deg, 1) """
The correct center is shape*0.5-0.5 This can be verified by:
def _augment(self, img, rot_m):
ret = cv2.warpAffine(img, rot_m, img.shape[1::-1], SHAPE = 7
flags=self.interp, borderMode=self.border, borderValue=self.border_value) arr = np.random.rand(SHAPE, SHAPE)
if img.ndim == 3 and ret.ndim == 2: orig = arr
ret = ret[:, :, np.newaxis] c = SHAPE * 0.5 - 0.5
return ret c = (c, c)
for k in range(4):
def _augment_coords(self, coords, param): mat = cv2.getRotationMatrix2D(c, 90, 1)
raise NotImplementedError() arr = cv2.warpAffine(arr, mat, arr.shape)
assert np.all(arr == orig)
"""
mat = cv2.getRotationMatrix2D(tuple(center - 0.5), deg, 1)
return WarpAffineTransform(mat, img.shape[1::-1], self.border, self.border_value)
class RotationAndCropValid(ImageAugmentor): class RotationAndCropValid(ImageAugmentor):
...@@ -152,7 +149,7 @@ class RotationAndCropValid(ImageAugmentor): ...@@ -152,7 +149,7 @@ class RotationAndCropValid(ImageAugmentor):
return int(np.round(wr)), int(np.round(hr)) return int(np.round(wr)), int(np.round(hr))
class Affine(ImageAugmentor): class Affine(TransformAugmentorBase):
""" """
Random affine transform of the image w.r.t to the image center. Random affine transform of the image w.r.t to the image center.
Transformations involve: Transformations involve:
...@@ -237,14 +234,5 @@ class Affine(ImageAugmentor): ...@@ -237,14 +234,5 @@ class Affine(ImageAugmentor):
# Apply shift : # Apply shift :
transform_matrix[0, 2] += dx transform_matrix[0, 2] += dx
transform_matrix[1, 2] += dy transform_matrix[1, 2] += dy
return transform_matrix return WarpAffineTransform(transform_matrix, img.shape[1::-1],
self.interp, self.border, self.border_value)
def _augment(self, img, transform_matrix):
ret = cv2.warpAffine(img, transform_matrix, img.shape[1::-1],
flags=self.interp, borderMode=self.border, borderValue=self.border_value)
if img.ndim == 3 and ret.ndim == 2:
ret = ret[:, :, np.newaxis]
return ret
def _augment_coords(self, coords, param):
raise NotImplementedError()
...@@ -85,3 +85,22 @@ class CropTransform(ImageTransform): ...@@ -85,3 +85,22 @@ class CropTransform(ImageTransform):
coords[:, 0] -= self.w0 coords[:, 0] -= self.w0
coords[:, 1] -= self.h0 coords[:, 1] -= self.h0
return coords return coords
class WarpAffineTransform(ImageTransform):
def __init__(self, mat, dsize, interp=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_CONSTANT, borderValue=0):
self._init(locals())
def apply_image(self, img):
ret = cv2.warpAffine(img, self.mat, self.dsize,
flags=self.interp,
borderMode=self.borderMode,
borderValue=self.borderValue)
if img.ndim == 3 and ret.ndim == 2:
ret = ret[:, :, np.newaxis]
return ret
def apply_coords(self, coords):
# TODO
raise NotImplementedError()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment