Commit 2ecdbc00 authored by Yuxin Wu's avatar Yuxin Wu

RandomResize accept range as pixels (fix #343)

parent 6d2a0aa1
...@@ -125,14 +125,16 @@ class ResizeShortestEdge(ImageAugmentor): ...@@ -125,14 +125,16 @@ class ResizeShortestEdge(ImageAugmentor):
class RandomResize(ImageAugmentor): class RandomResize(ImageAugmentor):
""" Randomly rescale w and h of the image""" """ Randomly rescale width and height of the image."""
def __init__(self, xrange, yrange, minimum=(0, 0), aspect_ratio_thres=0.15, def __init__(self, xrange, yrange, minimum=(0, 0), aspect_ratio_thres=0.15,
interp=cv2.INTER_LINEAR): interp=cv2.INTER_LINEAR):
""" """
Args: Args:
xrange (tuple): (min, max) range of scaling ratio for w, e.g. (0.9, 1.2) xrange (tuple): a (min, max) tuple. If is floating point, the
yrange (tuple): (min, max) range of scaling ratio for h tuple defines the range of scaling ratio of new width, e.g. (0.9, 1.2).
If is integer, the tuple defines the range of new width in pixels, e.g. (200, 350).
yrange (tuple): similar to xrange, but for height.
minimum (tuple): (xmin, ymin) in pixels. To avoid scaling down too much. minimum (tuple): (xmin, ymin) in pixels. To avoid scaling down too much.
aspect_ratio_thres (float): discard samples which change aspect ratio aspect_ratio_thres (float): discard samples which change aspect ratio
larger than this threshold. Set to 0 to keep aspect ratio. larger than this threshold. Set to 0 to keep aspect ratio.
...@@ -144,10 +146,17 @@ class RandomResize(ImageAugmentor): ...@@ -144,10 +146,17 @@ class RandomResize(ImageAugmentor):
assert xrange == yrange assert xrange == yrange
self._init(locals()) self._init(locals())
def is_float(tp):
return isinstance(tp[0], float) or isinstance(tp[1], float)
assert is_float(xrange) == is_float(yrange), "xrange and yrange has different type!"
self._is_scale = is_float(xrange)
def _get_augment_params(self, img): def _get_augment_params(self, img):
cnt = 0 cnt = 0
h, w = img.shape[:2] h, w = img.shape[:2]
while True:
def get_dest_size():
if self._is_scale:
sx = self._rand_range(*self.xrange) sx = self._rand_range(*self.xrange)
if self.aspect_ratio_thres == 0: if self.aspect_ratio_thres == 0:
sy = sx sy = sx
...@@ -155,15 +164,29 @@ class RandomResize(ImageAugmentor): ...@@ -155,15 +164,29 @@ class RandomResize(ImageAugmentor):
sy = self._rand_range(*self.yrange) sy = self._rand_range(*self.yrange)
destX = max(sx * w, self.minimum[0]) destX = max(sx * w, self.minimum[0])
destY = max(sy * h, self.minimum[1]) destY = max(sy * h, self.minimum[1])
else:
sx = int(self._rand_range(*self.xrange))
if self.aspect_ratio_thres == 0:
sy = sx * 1.0 / w * h
else:
sy = self._rand_range(*self.yrange)
destX = max(sx, self.minimum[0])
destY = max(sy, self.minimum[1])
return (destX, destY)
while True:
destX, destY = get_dest_size()
if self.aspect_ratio_thres > 0: # don't check when thres == 0
oldr = w * 1.0 / h oldr = w * 1.0 / h
newr = destX * 1.0 / destY newr = destX * 1.0 / destY
diff = abs(newr - oldr) / oldr diff = abs(newr - oldr) / oldr
if diff <= self.aspect_ratio_thres + 1e-5: if diff >= self.aspect_ratio_thres + 1e-5:
return (h, w, int(destY), int(destX))
cnt += 1 cnt += 1
if cnt > 50: if cnt > 50:
logger.warn("RandomResize failed to augment an image") logger.warn("RandomResize failed to augment an image")
return (h, w, h, w) return (h, w, h, w)
continue
return (h, w, int(destY), int(destX))
def _augment(self, img, param): def _augment(self, img, param):
_, _, newh, neww = param _, _, newh, neww = param
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment