Commit 2ecdbc00 authored by Yuxin Wu's avatar Yuxin Wu

RandomResize accept range as pixels (fix #343)

parent 6d2a0aa1
......@@ -125,14 +125,16 @@ class ResizeShortestEdge(ImageAugmentor):
class RandomResize(ImageAugmentor):
""" Randomly rescale w and h of the image"""
""" Randomly rescale width and height of the image."""
def __init__(self, xrange, yrange, minimum=(0, 0), aspect_ratio_thres=0.15,
interp=cv2.INTER_LINEAR):
"""
Args:
xrange (tuple): (min, max) range of scaling ratio for w, e.g. (0.9, 1.2)
yrange (tuple): (min, max) range of scaling ratio for h
xrange (tuple): a (min, max) tuple. If is floating point, the
tuple defines the range of scaling ratio of new width, e.g. (0.9, 1.2).
If is integer, the tuple defines the range of new width in pixels, e.g. (200, 350).
yrange (tuple): similar to xrange, but for height.
minimum (tuple): (xmin, ymin) in pixels. To avoid scaling down too much.
aspect_ratio_thres (float): discard samples which change aspect ratio
larger than this threshold. Set to 0 to keep aspect ratio.
......@@ -144,26 +146,47 @@ class RandomResize(ImageAugmentor):
assert xrange == yrange
self._init(locals())
def is_float(tp):
return isinstance(tp[0], float) or isinstance(tp[1], float)
assert is_float(xrange) == is_float(yrange), "xrange and yrange has different type!"
self._is_scale = is_float(xrange)
def _get_augment_params(self, img):
cnt = 0
h, w = img.shape[:2]
while True:
sx = self._rand_range(*self.xrange)
if self.aspect_ratio_thres == 0:
sy = sx
def get_dest_size():
if self._is_scale:
sx = self._rand_range(*self.xrange)
if self.aspect_ratio_thres == 0:
sy = sx
else:
sy = self._rand_range(*self.yrange)
destX = max(sx * w, self.minimum[0])
destY = max(sy * h, self.minimum[1])
else:
sy = self._rand_range(*self.yrange)
destX = max(sx * w, self.minimum[0])
destY = max(sy * h, self.minimum[1])
oldr = w * 1.0 / h
newr = destX * 1.0 / destY
diff = abs(newr - oldr) / oldr
if diff <= self.aspect_ratio_thres + 1e-5:
return (h, w, int(destY), int(destX))
cnt += 1
if cnt > 50:
logger.warn("RandomResize failed to augment an image")
return (h, w, h, w)
sx = int(self._rand_range(*self.xrange))
if self.aspect_ratio_thres == 0:
sy = sx * 1.0 / w * h
else:
sy = self._rand_range(*self.yrange)
destX = max(sx, self.minimum[0])
destY = max(sy, self.minimum[1])
return (destX, destY)
while True:
destX, destY = get_dest_size()
if self.aspect_ratio_thres > 0: # don't check when thres == 0
oldr = w * 1.0 / h
newr = destX * 1.0 / destY
diff = abs(newr - oldr) / oldr
if diff >= self.aspect_ratio_thres + 1e-5:
cnt += 1
if cnt > 50:
logger.warn("RandomResize failed to augment an image")
return (h, w, h, w)
continue
return (h, w, int(destY), int(destX))
def _augment(self, img, param):
_, _, newh, neww = param
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment