Commit 6aa8ab20 authored by Yuxin Wu

bug fix in imgaug

parent c4b38010
...@@ -25,6 +25,9 @@ You'll need tcmalloc to avoid large memory consumption: https://github.com/tenso ...@@ -25,6 +25,9 @@ You'll need tcmalloc to avoid large memory consumption: https://github.com/tenso
This config, with (W,A,G)=(1,1,4), can reach 3.1~3.2% error after 150 epochs. This config, with (W,A,G)=(1,1,4), can reach 3.1~3.2% error after 150 epochs.
With the GaussianDeform augmentor, it will reach 2.8~2.9% With the GaussianDeform augmentor, it will reach 2.8~2.9%
(we are not using this augmentor in the paper). (we are not using this augmentor in the paper).
with (W,A,G)=(1,2,4), error is 3.0~3.1%.
with (W,A,G)=(32,32,32), error is about 2.9%.
""" """
BITW = 1 BITW = 1
......
...@@ -61,7 +61,7 @@ class ImageAugmentor(object): ...@@ -61,7 +61,7 @@ class ImageAugmentor(object):
low, high = 0, low low, high = 0, low
if size == None: if size == None:
size = [] size = []
return low + self.rng.rand(*size) * (high - low) return self.rng.uniform(low, high, size)
class AugmentorList(ImageAugmentor): class AugmentorList(ImageAugmentor):
""" """
......
...@@ -20,8 +20,18 @@ class RandomCrop(ImageAugmentor): ...@@ -20,8 +20,18 @@ class RandomCrop(ImageAugmentor):
def _get_augment_params(self, img): def _get_augment_params(self, img):
orig_shape = img.shape orig_shape = img.shape
h0 = self.rng.randint(0, orig_shape[0] - self.crop_shape[0]) assert orig_shape[0] >= self.crop_shape[0] \
w0 = self.rng.randint(0, orig_shape[1] - self.crop_shape[1]) and orig_shape[1] >= self.crop_shape[1], orig_shape
diffh = orig_shape[0] - self.crop_shape[0]
if diffh == 0:
h0 = 0
else:
h0 = self.rng.randint(diffh)
diffw = orig_shape[1] - self.crop_shape[1]
if diffw == 0:
w0 = 0
else:
w0 = self.rng.randint(diffw)
return (h0, w0) return (h0, w0)
def _augment(self, img, param): def _augment(self, img, param):
......
...@@ -14,11 +14,11 @@ class JpegNoise(ImageAugmentor): ...@@ -14,11 +14,11 @@ class JpegNoise(ImageAugmentor):
self._init(locals()) self._init(locals())
def _get_augment_params(self, img): def _get_augment_params(self, img):
return self._rand_range(*self.quality_range) return self.rng.randint(*self.quality_range)
def _augment(self, img, q): def _augment(self, img, q):
return cv2.imdecode(cv2.imencode('.jpg', img, enc = cv2.imencode('.jpg', img, [cv2.IMWRITE_JPEG_QUALITY, q])[1]
[cv2.IMWRITE_JPEG_QUALITY, q])[1], 1) return cv2.imdecode(enc, 1)
class GaussianNoise(ImageAugmentor): class GaussianNoise(ImageAugmentor):
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
from .base import ImageAugmentor from .base import ImageAugmentor
from ...utils import logger
import numpy as np import numpy as np
import cv2 import cv2
...@@ -68,6 +69,7 @@ class RandomResize(ImageAugmentor): ...@@ -68,6 +69,7 @@ class RandomResize(ImageAugmentor):
self._init(locals()) self._init(locals())
def _get_augment_params(self, img): def _get_augment_params(self, img):
cnt = 0
while True: while True:
sx = self._rand_range(*self.xrange) sx = self._rand_range(*self.xrange)
sy = self._rand_range(*self.yrange) sy = self._rand_range(*self.yrange)
...@@ -78,6 +80,9 @@ class RandomResize(ImageAugmentor): ...@@ -78,6 +80,9 @@ class RandomResize(ImageAugmentor):
diff = abs(newr - oldr) / oldr diff = abs(newr - oldr) / oldr
if diff <= self.aspect_ratio_thres: if diff <= self.aspect_ratio_thres:
return (destX, destY) return (destX, destY)
cnt += 1
if cnt > 50:
logger.warn("RandomResize failed to augment an image")
def _augment(self, img, dsize): def _augment(self, img, dsize):
return cv2.resize(img, dsize, interpolation=cv2.INTER_CUBIC) return cv2.resize(img, dsize, interpolation=cv2.INTER_CUBIC)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment