Commit 61127c2d authored by Yuxin Wu's avatar Yuxin Wu

IntBox and FloatBox

parent c9226e90
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
from .base import ImageAugmentor from .base import ImageAugmentor
from ...utils.rect import Rect from ...utils.rect import IntBox
from ...utils.argtools import shape2d from ...utils.argtools import shape2d
from six.moves import range from six.moves import range
...@@ -81,7 +81,7 @@ def perturb_BB(image_shape, bb, max_perturb_pixel, ...@@ -81,7 +81,7 @@ def perturb_BB(image_shape, bb, max_perturb_pixel,
Args: Args:
image_shape: [h, w] image_shape: [h, w]
bb (Rect): original bounding box bb (IntBox): original bounding box
max_perturb_pixel: perturbation on each coordinate max_perturb_pixel: perturbation on each coordinate
max_aspect_ratio_diff: result can't have an aspect ratio too different from the original max_aspect_ratio_diff: result can't have an aspect ratio too different from the original
max_try: if cannot find a valid bounding box, return the original max_try: if cannot find a valid bounding box, return the original
...@@ -94,13 +94,11 @@ def perturb_BB(image_shape, bb, max_perturb_pixel, ...@@ -94,13 +94,11 @@ def perturb_BB(image_shape, bb, max_perturb_pixel,
for _ in range(max_try): for _ in range(max_try):
p = rng.randint(-max_perturb_pixel, max_perturb_pixel, [4]) p = rng.randint(-max_perturb_pixel, max_perturb_pixel, [4])
newbb = bb.copy() newbb = bb.copy()
newbb.x += p[0] newbb.x1 += p[0]
newbb.y += p[1] newbb.y1 += p[1]
newx1 = bb.x1 + p[2] newbb.x2 = bb.x2 + p[2]
newy1 = bb.y1 + p[3] newbb.y2 = bb.y2 + p[3]
newbb.w = newx1 - newbb.x if not newbb.is_valid_box(image_shape):
newbb.h = newy1 - newbb.y
if not newbb.validate(image_shape):
continue continue
new_ratio = newbb.h * 1.0 / newbb.w new_ratio = newbb.h * 1.0 / newbb.w
diff = abs(new_ratio - orig_ratio) diff = abs(new_ratio - orig_ratio)
...@@ -128,7 +126,7 @@ class RandomCropAroundBox(ImageAugmentor): ...@@ -128,7 +126,7 @@ class RandomCropAroundBox(ImageAugmentor):
def _get_augment_params(self, img): def _get_augment_params(self, img):
shape = img.shape[:2] shape = img.shape[:2]
box = Rect(0, 0, shape[1] - 1, shape[0] - 1) box = IntBox(0, 0, shape[1] - 1, shape[0] - 1)
dist = self.perturb_ratio * np.sqrt(shape[0] * shape[1]) dist = self.perturb_ratio * np.sqrt(shape[0] * shape[1])
newbox = perturb_BB(shape, box, dist, newbox = perturb_BB(shape, box, dist,
self.rng, self.max_aspect_ratio_diff) self.rng, self.max_aspect_ratio_diff)
...@@ -138,8 +136,8 @@ class RandomCropAroundBox(ImageAugmentor): ...@@ -138,8 +136,8 @@ class RandomCropAroundBox(ImageAugmentor):
return newbox.roi(img) return newbox.roi(img)
def _augment_coords(self, coords, newbox): def _augment_coords(self, coords, newbox):
coords[:, 0] = coords[:, 0] - newbox.x0 coords[:, 0] = coords[:, 0] - newbox.x1
coords[:, 1] = coords[:, 1] - newbox.y0 coords[:, 1] = coords[:, 1] - newbox.y1
return coords return coords
...@@ -185,4 +183,4 @@ class RandomCropRandomShape(ImageAugmentor): ...@@ -185,4 +183,4 @@ class RandomCropRandomShape(ImageAugmentor):
if __name__ == '__main__': if __name__ == '__main__':
print(perturb_BB([100, 100], Rect(3, 3, 50, 50), 50)) print(perturb_BB([100, 100], IntBox(3, 3, 50, 50), 50))
...@@ -5,38 +5,17 @@ ...@@ -5,38 +5,17 @@
import numpy as np import numpy as np
__all__ = ['IntBox', 'FloatBox']
class Rect(object):
"""
A rectangle class.
Note that x1 = x + w, not x+w-1 or something else. class BoxBase(object):
""" __slots__ = ['x1', 'y1', 'x2', 'y2']
__slots__ = ['x', 'y', 'w', 'h']
def __init__(self, x=0, y=0, w=0, h=0, allow_neg=False): def __init__(self, x1, y1, x2, y2):
self.x = x self.x1 = x1
self.y = y self.y1 = y1
self.w = w self.x2 = x2
self.h = h self.y2 = y2
if not allow_neg:
assert min(self.x, self.y, self.w, self.h) >= 0
@property
def x0(self):
return self.x
@property
def y0(self):
return self.y
@property
def x1(self):
return self.x + self.w
@property
def y1(self):
return self.y + self.h
def copy(self): def copy(self):
new = type(self)() new = type(self)()
...@@ -45,66 +24,103 @@ class Rect(object): ...@@ -45,66 +24,103 @@ class Rect(object):
return new return new
def __str__(self): def __str__(self):
return 'Rect(x={}, y={}, w={}, h={})'.format(self.x, self.y, self.w, self.h) return '{}(x1={}, y1={}, x2={}, y2={})'.format(
type(self).__name__, self.x1, self.y1, self.x2, self.y2)
__repr__ = __str__
def area(self): def area(self):
return self.w * self.h return self.w * self.h
def validate(self, shape=None): def is_box(self):
return self.area() > 0
class IntBox(BoxBase):
def __init__(self, x1, y1, x2, y2):
for k in [x1, y1, x2, y2]:
assert isinstance(k, int)
super(IntBox, self).__init__(x1, y1, x2, y2)
@property
def w(self):
return self.x2 - self.x1 + 1
@property
def h(self):
return self.y2 - self.y1 + 1
def is_valid_box(self, shape):
""" """
Check that this rect is a valid bounding box within this shape. Check that this rect is a valid bounding box within this shape.
Args: Args:
shape: [h, w] shape: int [h, w] or None.
Returns: Returns:
bool bool
""" """
if min(self.x, self.y) < 0: if min(self.x1, self.y1) < 0:
return False return False
if min(self.w, self.h) <= 0: if min(self.w, self.h) <= 0:
return False return False
if shape is None: if self.x2 >= shape[1]:
return True
if self.x1 > shape[1] - 1:
return False return False
if self.y1 > shape[0] - 1: if self.y2 >= shape[0]:
return False return False
return True return True
def roi(self, img): def roi(self, img):
assert self.validate(img.shape[:2]), "{} vs {}".format(self, img.shape[:2]) assert self.validate(img.shape[:2]), "{} vs {}".format(self, img.shape[:2])
return img[self.y0:self.y1 + 1, self.x0:self.x1 + 1] return img[self.y1:self.y2 + 1, self.x1:self.x2 + 1]
def expand(self, frac): # def expand(self, frac):
assert frac > 1.0, frac # assert frac > 1.0, frac
neww = self.w * frac # neww = self.w * frac
newh = self.h * frac # newh = self.h * frac
newx = self.x - (neww - self.w) * 0.5 # newx = self.x - (neww - self.w) * 0.5
newy = self.y - (newh - self.h) * 0.5 # newy = self.y - (newh - self.h) * 0.5
return Rect(*(map(int, [newx, newy, neww, newh])), allow_neg=True) # return Rect(*(map(int, [newx, newy, neww, newh])), allow_neg=True)
def roi_zeropad(self, img): # def roi_zeropad(self, img):
shp = list(img.shape) # shp = list(img.shape)
shp[0] = self.h # shp[0] = self.h
shp[1] = self.w # shp[1] = self.w
ret = np.zeros(tuple(shp), dtype=img.dtype) # ret = np.zeros(tuple(shp), dtype=img.dtype)
xstart = 0 if self.x >= 0 else -self.x # xstart = 0 if self.x >= 0 else -self.x
ystart = 0 if self.y >= 0 else -self.y # ystart = 0 if self.y >= 0 else -self.y
xmin = max(self.x0, 0) # xmin = max(self.x0, 0)
ymin = max(self.y0, 0) # ymin = max(self.y0, 0)
xmax = min(self.x1, img.shape[1]) # xmax = min(self.x1, img.shape[1])
ymax = min(self.y1, img.shape[0]) # ymax = min(self.y1, img.shape[0])
patch = img[ymin:ymax, xmin:xmax] # patch = img[ymin:ymax, xmin:xmax]
ret[ystart:ystart + patch.shape[0], xstart:xstart + patch.shape[1]] = patch # ret[ystart:ystart + patch.shape[0], xstart:xstart + patch.shape[1]] = patch
return ret # return ret
class FloatBox(BoxBase):
def __init__(self, x1, y1, x2, y2):
for k in [x1, y1, x2, y2]:
assert isinstance(k, float)
super(FloatBox, self).__init__(x1, y1, x2, y2)
__repr__ = __str__ @property
def w(self):
return self.x2 - self.x1
@property
def h(self):
return self.y2 - self.y1
@staticmethod
def from_intbox(intbox):
return FloatBox(intbox.x1, intbox.y1,
intbox.x2 + 1, intbox.y2 + 1)
if __name__ == '__main__': if __name__ == '__main__':
x = Rect(2, 1, 3, 3, allow_neg=True) x = IntBox(2, 1, 3, 3)
img = np.random.rand(3, 3) img = np.random.rand(3, 3)
print(img) print(img)
print(x.roi_zeropad(img))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment