Commit 67cfcdee authored by Yuxin Wu's avatar Yuxin Wu

fix augmentor bug

parent fb217486
...@@ -77,6 +77,9 @@ class AugmentorList(ImageAugmentor): ...@@ -77,6 +77,9 @@ class AugmentorList(ImageAugmentor):
raise RuntimeError("Cannot simply get parameters of a AugmentorList!") raise RuntimeError("Cannot simply get parameters of a AugmentorList!")
def _augment_return_params(self, img): def _augment_return_params(self, img):
assert img.ndim in [2, 3], img.ndim
img = img.astype('float32')
prms = [] prms = []
for a in self.augs: for a in self.augs:
img, prm = a._augment_return_params(img) img, prm = a._augment_return_params(img)
......
...@@ -88,7 +88,6 @@ class GaussianDeform(ImageAugmentor): ...@@ -88,7 +88,6 @@ class GaussianDeform(ImageAugmentor):
def _augment(self, img, v): def _augment(self, img, v):
grid = self.grid + np.dot(self.gws, v) grid = self.grid + np.dot(self.gws, v)
print(grid)
return np_sample(img, grid) return np_sample(img, grid)
def _fprop_coord(self, coord, param): def _fprop_coord(self, coord, param):
......
...@@ -54,7 +54,7 @@ class MapImage(ImageAugmentor): ...@@ -54,7 +54,7 @@ class MapImage(ImageAugmentor):
self.func = func self.func = func
def _augment(self, img, _): def _augment(self, img, _):
img = self.func(img) return self.func(img)
class Resize(ImageAugmentor): class Resize(ImageAugmentor):
...@@ -66,6 +66,6 @@ class Resize(ImageAugmentor): ...@@ -66,6 +66,6 @@ class Resize(ImageAugmentor):
self._init(locals()) self._init(locals())
def _augment(self, img, _): def _augment(self, img, _):
img.arr = cv2.resize( return cv2.resize(
img.arr, self.shape[::-1], img, self.shape[::-1],
interpolation=cv2.INTER_CUBIC) interpolation=cv2.INTER_CUBIC)
...@@ -67,7 +67,8 @@ class CenterPaste(ImageAugmentor): ...@@ -67,7 +67,8 @@ class CenterPaste(ImageAugmentor):
w0 = int((self.background_shape[1] - img_shape[1]) * 0.5) w0 = int((self.background_shape[1] - img_shape[1]) * 0.5)
background[h0:h0+img_shape[0], w0:w0+img_shape[1]] = img background[h0:h0+img_shape[0], w0:w0+img_shape[1]] = img
img = background img = background
if img.coords: return img
raise NotImplementedError()
def _fprop_coord(self, coord, param):
raise NotImplementedError()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment