Commit 6dfbf89b authored by Yuxin Wu

augmentation

parent 5bf3bc81
...@@ -28,8 +28,8 @@ class BSDS500(RNGDataFlow): ...@@ -28,8 +28,8 @@ class BSDS500(RNGDataFlow):
Produce (image, label) pair, where image has shape (321, 481, 3) and Produce (image, label) pair, where image has shape (321, 481, 3) and
ranges in [0,255]. Label is binary and has shape (321, 481). ranges in [0,255]. Label is binary and has shape (321, 481).
Those pixels annotated as boundaries by >= 3 annotators are Those pixels annotated as boundaries by <=2 annotators are set to 0.
considered positive examples. This is used in `Holistically-Nested Edge Detection This is used in `Holistically-Nested Edge Detection
<http://arxiv.org/abs/1504.06375>`_. <http://arxiv.org/abs/1504.06375>`_.
""" """
...@@ -73,8 +73,9 @@ class BSDS500(RNGDataFlow): ...@@ -73,8 +73,9 @@ class BSDS500(RNGDataFlow):
gt = loadmat(gt_file)['groundTruth'][0] gt = loadmat(gt_file)['groundTruth'][0]
n_annot = gt.shape[0] n_annot = gt.shape[0]
gt = sum(gt[k]['Boundaries'][0][0] for k in range(n_annot)) gt = sum(gt[k]['Boundaries'][0][0] for k in range(n_annot))
gt[gt > 3] = 3 gt[gt <= 2] = 0
gt = gt / 3.0 gt = gt.astype('float32')
gt /= np.max(gt)
if gt.shape[0] > gt.shape[1]: if gt.shape[0] > gt.shape[1]:
gt = gt.transpose() gt = gt.transpose()
assert gt.shape == (IMG_H, IMG_W) assert gt.shape == (IMG_H, IMG_W)
......
...@@ -9,7 +9,7 @@ from .base import DataFlow, ProxyDataFlow ...@@ -9,7 +9,7 @@ from .base import DataFlow, ProxyDataFlow
from .common import MapDataComponent, MapData from .common import MapDataComponent, MapData
from .imgaug import AugmentorList from .imgaug import AugmentorList
__all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImagesTogether'] __all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageComponents']
class ImageFromFile(DataFlow): class ImageFromFile(DataFlow):
""" Generate rgb images from list of files """ """ Generate rgb images from list of files """
...@@ -45,7 +45,7 @@ class AugmentImageComponent(MapDataComponent): ...@@ -45,7 +45,7 @@ class AugmentImageComponent(MapDataComponent):
""" """
:param ds: a `DataFlow` instance. :param ds: a `DataFlow` instance.
:param augmentors: a list of `ImageAugmentor` instance to be applied in order. :param augmentors: a list of `ImageAugmentor` instance to be applied in order.
:param index: the index of the image component in the produced datapoints by `ds`. default to be 0 :param index: the index (or list of indices) of the image component in the produced datapoints by `ds`. default to be 0
""" """
self.augs = AugmentorList(augmentors) self.augs = AugmentorList(augmentors)
super(AugmentImageComponent, self).__init__( super(AugmentImageComponent, self).__init__(
...@@ -56,7 +56,7 @@ class AugmentImageComponent(MapDataComponent): ...@@ -56,7 +56,7 @@ class AugmentImageComponent(MapDataComponent):
self.augs.reset_state() self.augs.reset_state()
class AugmentImagesTogether(MapData): class AugmentImageComponents(MapData):
""" Augment a list of images of the same shape, with the same parameters""" """ Augment a list of images of the same shape, with the same parameters"""
def __init__(self, ds, augmentors, index=(0,1)): def __init__(self, ds, augmentors, index=(0,1)):
""" """
...@@ -75,7 +75,7 @@ class AugmentImagesTogether(MapData): ...@@ -75,7 +75,7 @@ class AugmentImagesTogether(MapData):
dp[idx] = self.augs._augment(dp[idx], prms) dp[idx] = self.augs._augment(dp[idx], prms)
return dp return dp
super(AugmentImagesTogether, self).__init__(ds, func) super(AugmentImageComponents, self).__init__(ds, func)
def reset_state(self): def reset_state(self):
self.ds.reset_state() self.ds.reset_state()
......
...@@ -4,10 +4,11 @@ ...@@ -4,10 +4,11 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from .base import ImageAugmentor from .base import ImageAugmentor
import math
import cv2 import cv2
import numpy as np import numpy as np
__all__ = ['Rotation'] __all__ = ['Rotation', 'RotationAndCropValid']
class Rotation(ImageAugmentor): class Rotation(ImageAugmentor):
""" Random rotate the image w.r.t a random center""" """ Random rotate the image w.r.t a random center"""
...@@ -31,3 +32,51 @@ class Rotation(ImageAugmentor): ...@@ -31,3 +32,51 @@ class Rotation(ImageAugmentor):
flags=self.interp, borderMode=self.border) flags=self.interp, borderMode=self.border)
return ret return ret
class RotationAndCropValid(ImageAugmentor):
    """ Random rotate and crop the largest possible rect without the border.
        This will produce images of different shapes.
    """
    def __init__(self, max_deg, interp=cv2.INTER_CUBIC):
        """
        :param max_deg: bound of the rotation; the angle (in degrees) is drawn
            from the range [-max_deg, max_deg].
        :param interp: interpolation flag passed to `cv2.warpAffine`.
        """
        # NOTE(review): `_init` is defined on the base class; presumably it
        # stores the arguments as attributes, since `self.max_deg` and
        # `self.interp` are read below -- confirm against ImageAugmentor.
        self._init(locals())

    def _get_augment_params(self, img):
        """ Draw the rotation angle (in degrees) for one image. """
        return self._rand_range(-self.max_deg, self.max_deg)

    def _augment(self, img, deg):
        """
        Rotate `img` by `deg` degrees around its center, then crop the
        largest axis-aligned rectangle that contains no border pixels.

        :param img: image array; only its first two dimensions (height,
            width) are used to compute the geometry.
        :param deg: rotation angle in degrees.
        :returns: the cropped, rotated image. Its shape depends on `deg`.
        """
        height, width = img.shape[:2]
        center = (width * 0.5, height * 0.5)
        rot_m = cv2.getRotationMatrix2D(center, deg, 1)
        # warpAffine takes the output size as (width, height)
        ret = cv2.warpAffine(img, rot_m, img.shape[1::-1],
                             flags=self.interp,
                             borderMode=cv2.BORDER_CONSTANT)
        neww, newh = RotationAndCropValid.largest_rotated_rect(
            ret.shape[1], ret.shape[0], deg)
        # Slice bounds must be integers: float bounds raise a TypeError in
        # NumPy. Truncating (rather than rounding) keeps the crop inside the
        # border-free rectangle.
        neww = int(min(neww, ret.shape[1]))
        newh = int(min(newh, ret.shape[0]))
        newx = int(center[0] - neww * 0.5)
        newy = int(center[1] - newh * 0.5)
        return ret[newy:newy + newh, newx:newx + neww]

    @staticmethod
    def largest_rotated_rect(w, h, angle):
        """
        Compute the size of the largest axis-aligned rectangle that fits
        inside a `w` x `h` rectangle rotated by `angle`.
        See http://stackoverflow.com/questions/16702966/rotate-image-and-crop-out-black-borders

        :param w: width of the original rectangle.
        :param h: height of the original rectangle.
        :param angle: rotation angle in degrees.
        :returns: a tuple (width, height) of the inscribed rectangle,
            or (0, 0) if `w` or `h` is not positive.
        """
        angle = angle / 180.0 * math.pi
        if w <= 0 or h <= 0:
            return 0, 0

        width_is_longer = w >= h
        side_long, side_short = (w, h) if width_is_longer else (h, w)

        # since the solutions for angle, -angle and 180-angle are all the same,
        # it suffices to look at the first quadrant and the absolute values of sin,cos:
        sin_a, cos_a = abs(math.sin(angle)), abs(math.cos(angle))
        if side_short <= 2. * sin_a * cos_a * side_long:
            # half constrained case: two crop corners touch the longer side,
            # the other two corners are on the mid-line parallel to the longer line
            x = 0.5 * side_short
            if width_is_longer:
                wr, hr = x / sin_a, x / cos_a
            else:
                wr, hr = x / cos_a, x / sin_a
        else:
            # fully constrained case: crop touches all 4 sides
            cos_2a = cos_a * cos_a - sin_a * sin_a
            wr = (w * cos_a - h * sin_a) / cos_2a
            hr = (h * cos_a - w * sin_a) / cos_2a
        return wr, hr
...@@ -6,7 +6,8 @@ from .base import ImageAugmentor ...@@ -6,7 +6,8 @@ from .base import ImageAugmentor
import numpy as np import numpy as np
import cv2 import cv2
__all__ = ['Brightness', 'Contrast', 'MeanVarianceNormalize', 'GaussianBlur', 'Gamma'] __all__ = ['Brightness', 'Contrast', 'MeanVarianceNormalize', 'GaussianBlur',
'Gamma', 'Clip']
class Brightness(ImageAugmentor): class Brightness(ImageAugmentor):
""" """
...@@ -102,3 +103,11 @@ class Gamma(ImageAugmentor): ...@@ -102,3 +103,11 @@ class Gamma(ImageAugmentor):
img = cv2.LUT(img, lut).astype('float32') img = cv2.LUT(img, lut).astype('float32')
return img return img
class Clip(ImageAugmentor):
    """ Clip the values of the image into the range [min, max]. """
    def __init__(self, min=0, max=255):
        """
        :param min: lower bound passed to `np.clip`.
        :param max: upper bound passed to `np.clip`.
        :raises ValueError: if `min` is greater than `max`.
        """
        # The previous sanity check referenced an undefined name (`delta`),
        # so constructing the augmentor always failed with a NameError.
        # Validate the actual arguments instead.
        if min > max:
            raise ValueError(
                "Clip: min ({}) must not be greater than max ({})".format(min, max))
        # NOTE(review): `_init` is defined on the base class; presumably it
        # stores the arguments as attributes, since `self.min` and `self.max`
        # are read in `_augment` -- confirm against ImageAugmentor.
        self._init(locals())

    def _augment(self, img, _):
        """ Return `img` with its values limited to [self.min, self.max]. """
        return np.clip(img, self.min, self.max)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment