Commit 3397e0bd authored by Yuxin Wu's avatar Yuxin Wu

fix when opencv reshapes 1-channel image (fix #184)

parent 6cce6e01
...@@ -37,6 +37,8 @@ class Shift(ImageAugmentor): ...@@ -37,6 +37,8 @@ class Shift(ImageAugmentor):
def _augment(self, img, shift_m): def _augment(self, img, shift_m):
ret = cv2.warpAffine(img, shift_m, img.shape[1::-1], ret = cv2.warpAffine(img, shift_m, img.shape[1::-1],
borderMode=self.border) borderMode=self.border)
if img.ndim == 3 and ret.ndim == 2:
ret = ret[:, :, np.newaxis]
return ret return ret
...@@ -71,6 +73,8 @@ class Rotation(ImageAugmentor): ...@@ -71,6 +73,8 @@ class Rotation(ImageAugmentor):
def _augment(self, img, rot_m): def _augment(self, img, rot_m):
ret = cv2.warpAffine(img, rot_m, img.shape[1::-1], ret = cv2.warpAffine(img, rot_m, img.shape[1::-1],
flags=self.interp, borderMode=self.border) flags=self.interp, borderMode=self.border)
if img.ndim == 3 and ret.ndim == 2:
ret = ret[:, :, np.newaxis]
return ret return ret
...@@ -99,6 +103,8 @@ class RotationAndCropValid(ImageAugmentor): ...@@ -99,6 +103,8 @@ class RotationAndCropValid(ImageAugmentor):
rot_m = cv2.getRotationMatrix2D((center[0] - 0.5, center[1] - 0.5), deg, 1) rot_m = cv2.getRotationMatrix2D((center[0] - 0.5, center[1] - 0.5), deg, 1)
ret = cv2.warpAffine(img, rot_m, img.shape[1::-1], ret = cv2.warpAffine(img, rot_m, img.shape[1::-1],
flags=self.interp, borderMode=cv2.BORDER_CONSTANT) flags=self.interp, borderMode=cv2.BORDER_CONSTANT)
if img.ndim == 3 and ret.ndim == 2:
ret = ret[:, :, np.newaxis]
neww, newh = RotationAndCropValid.largest_rotated_rect(ret.shape[1], ret.shape[0], deg) neww, newh = RotationAndCropValid.largest_rotated_rect(ret.shape[1], ret.shape[0], deg)
neww = min(neww, ret.shape[1]) neww = min(neww, ret.shape[1])
newh = min(newh, ret.shape[0]) newh = min(newh, ret.shape[0])
......
...@@ -152,8 +152,10 @@ class Gamma(ImageAugmentor): ...@@ -152,8 +152,10 @@ class Gamma(ImageAugmentor):
old_dtype = img.dtype old_dtype = img.dtype
lut = ((np.arange(256, dtype='float32') / 255) ** (1. / (1. + gamma)) * 255).astype('uint8') lut = ((np.arange(256, dtype='float32') / 255) ** (1. / (1. + gamma)) * 255).astype('uint8')
img = np.clip(img, 0, 255).astype('uint8') img = np.clip(img, 0, 255).astype('uint8')
img = cv2.LUT(img, lut).astype(old_dtype) ret = cv2.LUT(img, lut).astype(old_dtype)
return img if img.ndim == 3 and ret.ndim == 2:
ret = ret[:, :, np.newaxis]
return ret
class Clip(ImageAugmentor): class Clip(ImageAugmentor):
......
...@@ -39,8 +39,12 @@ class Flip(ImageAugmentor): ...@@ -39,8 +39,12 @@ class Flip(ImageAugmentor):
def _augment(self, img, do): def _augment(self, img, do):
if do: if do:
img = cv2.flip(img, self.code) ret = cv2.flip(img, self.code)
return img if img.ndim == 3 and ret.ndim == 2:
ret = ret[:, :, np.newaxis]
else:
ret = img
return ret
def _fprop_coord(self, coord, param): def _fprop_coord(self, coord, param):
raise NotImplementedError() raise NotImplementedError()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment