Commit c9fde630 authored by Yuxin Wu's avatar Yuxin Wu Committed by GitHub

Make augmentors return a `Transform` instance. (#1290)

* pure-image can run

* external; tests

* keep backward compatibility

* Remove the need for "augment_return_xxx" by using `LazyTransform`.

* better printing and docs

* Let PhotometricAugmentor's interface similar to the old one

* update docs

* udpate examples

* update tests & warnings
parent 5af84e93
...@@ -376,6 +376,7 @@ _DEPRECATED_NAMES = set([ ...@@ -376,6 +376,7 @@ _DEPRECATED_NAMES = set([
'DistributedTrainerParameterServer', 'DistributedTrainerParameterServer',
'InputDesc', 'InputDesc',
'inputs_desc', 'inputs_desc',
'Augmentor',
# renamed items that should not appear in docs # renamed items that should not appear in docs
'DumpTensor', 'DumpTensor',
......
...@@ -23,22 +23,7 @@ as the DataFlow. ...@@ -23,22 +23,7 @@ as the DataFlow.
In other words, for simple mapping you do not need to write an augmentor. In other words, for simple mapping you do not need to write an augmentor.
An augmentor may do something more than just applying a mapping. An augmentor may do something more than just applying a mapping.
To do complicated augmentation, the interface you will need to implement is: To do custom augmentation, you can implement one yourself.
```python
class MyAug(imgaug.ImageAugmentor):
def _get_augment_params(self, img):
# Generated random params with self.rng
return params
def _augment(self, img, params):
return augmented_img
# optional method
def _augment_coords(self, coords, param):
# coords is a Nx2 floating point array, each row is (x, y)
return augmented_coords
```
#### The Design of imgaug Module #### The Design of imgaug Module
...@@ -46,29 +31,59 @@ class MyAug(imgaug.ImageAugmentor): ...@@ -46,29 +31,59 @@ class MyAug(imgaug.ImageAugmentor):
The [imgaug module](../../modules/dataflow.imgaug.html) is designed to allow the following usage: The [imgaug module](../../modules/dataflow.imgaug.html) is designed to allow the following usage:
* Factor out randomness and determinism. * Factor out randomness and determinism.
An augmentor may be randomized, but you can call An augmentor often contains randomized policy, e.g., it randomly perturbs each image differently.
[augment_return_params](../../modules/dataflow.imgaug.html#tensorpack.dataflow.imgaug.Augmentor.augment_return_params) However, its "deterministic" part needs to be factored out, so that
to obtain the randomized parameters and then call the same transformation can be re-applied to other data
[augment_with_params](../../modules/dataflow.imgaug.html#tensorpack.dataflow.imgaug.Augmentor.augment_with_params) assocaited with the image. This is achieved like this:
on other data with the same randomized parameters.
```python
* Because of the above reason, tensorpack's augmentor can augment multiple images together tfm = augmentor.get_transform(img) # a deterministic transformation
easily. This is commonly used for augmenting an image together with its masks. new_img = tfm.apply_image(img)
new_img2 = tfm.apply_image(img2)
* An image augmentor (e.g. flip) may also augment a coordinate, with new_coords = tfm.apply_coords(coords)
[augment_coords](../../modules/dataflow.imgaug.html#tensorpack.dataflow.imgaug.ImageAugmentor.augment_coords). ```
In this way, images can be augmented together with
boxes, polygons, keypoints, etc. Due to this design, it can augment images together with its annotations
Coordinate augmentation enforces floating points coordinates (e.g., segmentation masks, bounding boxes, keypoints).
Our coordinate augmentation enforces floating points coordinates
to avoid quantization error. to avoid quantization error.
When you don't need to re-apply the same transformation, you can also just call
```python
new_img = augmentor.augment(img)
```
* Reset random seed. Random seed can be reset by * Reset random seed. Random seed can be reset by
[reset_state](../../modules/dataflow.imgaug.html#tensorpack.dataflow.imgaug.Augmentor.reset_state). [reset_state](../../modules/dataflow.imgaug.html#tensorpack.dataflow.imgaug.ImageAugmentor.reset_state).
This is important for multi-process data loading, to make sure different This is important for multi-process data loading, to make sure different
processes get different seeds. processes get different seeds.
The reset method is called automatically if you use tensorpack's The reset method is called automatically if you use tensorpack's
[image augmentation dataflow](../../modules/dataflow.html#tensorpack.dataflow.AugmentImageComponent). [image augmentation dataflow](../../modules/dataflow.html#tensorpack.dataflow.AugmentImageComponent)
or if you use Python 3.7+.
Otherwise, **you are responsible** for calling it by yourself in subprocesses. Otherwise, **you are responsible** for calling it by yourself in subprocesses.
See the See the
[API documentation](../../modules/dataflow.imgaug.html#tensorpack.dataflow.imgaug.Augmentor.reset_state) [API documentation](../../modules/dataflow.imgaug.html#tensorpack.dataflow.imgaug.ImageAugmentor.reset_state)
of this method for more details. of this method for more details.
### Write an Augmentor
The interface you will need to implement is:
```python
class MyAug(imgaug.ImageAugmentor):
def get_transform(self, img):
# Randomly generate a deterministic transformation, to be applied on img
x = random_parameters()
return MyTransform(x)
class MyTransform(imgaug.Transform):
def apply_image(self, img):
return new_img
def apply_coords(self, coords):
return new_coords
```
Check out the zoo of builtin augmentors to have a better sense.
...@@ -5,7 +5,7 @@ import numpy as np ...@@ -5,7 +5,7 @@ import numpy as np
import cv2 import cv2
from tensorpack.dataflow import RNGDataFlow from tensorpack.dataflow import RNGDataFlow
from tensorpack.dataflow.imgaug import transform from tensorpack.dataflow.imgaug import ImageAugmentor, ResizeTransform
class DataFromListOfDict(RNGDataFlow): class DataFromListOfDict(RNGDataFlow):
...@@ -26,7 +26,7 @@ class DataFromListOfDict(RNGDataFlow): ...@@ -26,7 +26,7 @@ class DataFromListOfDict(RNGDataFlow):
yield dp yield dp
class CustomResize(transform.TransformAugmentorBase): class CustomResize(ImageAugmentor):
""" """
Try resizing the shortest edge to a certain number Try resizing the shortest edge to a certain number
while avoiding the longest edge to exceed max_size. while avoiding the longest edge to exceed max_size.
...@@ -44,7 +44,7 @@ class CustomResize(transform.TransformAugmentorBase): ...@@ -44,7 +44,7 @@ class CustomResize(transform.TransformAugmentorBase):
short_edge_length = (short_edge_length, short_edge_length) short_edge_length = (short_edge_length, short_edge_length)
self._init(locals()) self._init(locals())
def _get_augment_params(self, img): def get_transform(self, img):
h, w = img.shape[:2] h, w = img.shape[:2]
size = self.rng.randint( size = self.rng.randint(
self.short_edge_length[0], self.short_edge_length[1] + 1) self.short_edge_length[0], self.short_edge_length[1] + 1)
...@@ -59,7 +59,7 @@ class CustomResize(transform.TransformAugmentorBase): ...@@ -59,7 +59,7 @@ class CustomResize(transform.TransformAugmentorBase):
neww = neww * scale neww = neww * scale
neww = int(neww + 0.5) neww = int(neww + 0.5)
newh = int(newh + 0.5) newh = int(newh + 0.5)
return transform.ResizeTransform(h, w, newh, neww, self.interp) return ResizeTransform(h, w, newh, neww, self.interp)
def box_to_point8(boxes): def box_to_point8(boxes):
......
...@@ -90,9 +90,10 @@ class TrainingDataPreprocessor: ...@@ -90,9 +90,10 @@ class TrainingDataPreprocessor:
boxes[:, 1::2] *= height boxes[:, 1::2] *= height
# augmentation: # augmentation:
im, params = self.aug.augment_return_params(im) tfms = self.aug.get_transform(im)
im = tfms.apply_image(im)
points = box_to_point8(boxes) points = box_to_point8(boxes)
points = self.aug.augment_coords(points, params) points = tfms.apply_coords(points)
boxes = point8_to_box(points) boxes = point8_to_box(points)
if len(boxes): if len(boxes):
assert np.min(np_area(boxes)) > 0, "Some boxes have zero area!" assert np.min(np_area(boxes)) > 0, "Some boxes have zero area!"
...@@ -131,7 +132,7 @@ class TrainingDataPreprocessor: ...@@ -131,7 +132,7 @@ class TrainingDataPreprocessor:
for polys in segmentation: for polys in segmentation:
if not self.cfg.DATA.ABSOLUTE_COORD: if not self.cfg.DATA.ABSOLUTE_COORD:
polys = [p * width_height for p in polys] polys = [p * width_height for p in polys]
polys = [self.aug.augment_coords(p, params) for p in polys] polys = [tfms.apply_coords(p) for p in polys]
masks.append(segmentation_to_mask(polys, im.shape[0], gt_mask_width)) masks.append(segmentation_to_mask(polys, im.shape[0], gt_mask_width))
if len(masks): if len(masks):
......
...@@ -191,7 +191,7 @@ def get_data(name): ...@@ -191,7 +191,7 @@ def get_data(name):
ds = dataset.BSDS500(name, shuffle=True) ds = dataset.BSDS500(name, shuffle=True)
class CropMultiple16(imgaug.ImageAugmentor): class CropMultiple16(imgaug.ImageAugmentor):
def _get_augment_params(self, img): def get_transform(self, img):
newh = img.shape[0] // 16 * 16 newh = img.shape[0] // 16 * 16
neww = img.shape[1] // 16 * 16 neww = img.shape[1] // 16 * 16
assert newh > 0 and neww > 0 assert newh > 0 and neww > 0
...@@ -199,11 +199,7 @@ def get_data(name): ...@@ -199,11 +199,7 @@ def get_data(name):
h0 = 0 if diffh == 0 else self.rng.randint(diffh) h0 = 0 if diffh == 0 else self.rng.randint(diffh)
diffw = img.shape[1] - neww diffw = img.shape[1] - neww
w0 = 0 if diffw == 0 else self.rng.randint(diffw) w0 = 0 if diffw == 0 else self.rng.randint(diffw)
return (h0, w0, newh, neww) return imgaug.CropTransform(h0, w0, newh, neww)
def _augment(self, img, param):
h0, w0, newh, neww = param
return img[h0:h0 + newh, w0:w0 + neww]
if isTrain: if isTrain:
shape_aug = [ shape_aug = [
......
...@@ -161,10 +161,9 @@ class AugmentImageCoordinates(MapData): ...@@ -161,10 +161,9 @@ class AugmentImageCoordinates(MapData):
validate_coords(coords) validate_coords(coords)
if self._copy: if self._copy:
img, coords = copy_mod.deepcopy((img, coords)) img, coords = copy_mod.deepcopy((img, coords))
img, prms = self.augs.augment_return_params(img) tfms = self.augs.get_transform(img)
dp[self._img_index] = img dp[self._img_index] = tfms.apply_image(img)
coords = self.augs.augment_coords(coords, prms) dp[self._coords_index] = tfms.apply_coords(coords)
dp[self._coords_index] = coords
return dp return dp
...@@ -207,15 +206,15 @@ class AugmentImageComponents(MapData): ...@@ -207,15 +206,15 @@ class AugmentImageComponents(MapData):
major_image = index[0] # image to be used to get params. TODO better design? major_image = index[0] # image to be used to get params. TODO better design?
im = copy_func(dp[major_image]) im = copy_func(dp[major_image])
check_dtype(im) check_dtype(im)
im, prms = self.augs.augment_return_params(im) tfms = self.augs.get_transform(im)
dp[major_image] = im dp[major_image] = tfms.apply_image(im)
for idx in index[1:]: for idx in index[1:]:
check_dtype(dp[idx]) check_dtype(dp[idx])
dp[idx] = self.augs.augment_with_params(copy_func(dp[idx]), prms) dp[idx] = tfms.apply_image(copy_func(dp[idx]))
for idx in coords_index: for idx in coords_index:
coords = copy_func(dp[idx]) coords = copy_func(dp[idx])
validate_coords(coords) validate_coords(coords)
dp[idx] = self.augs.augment_coords(coords, prms) dp[idx] = tfms.apply_coords(coords)
return dp return dp
super(AugmentImageComponents, self).__init__(ds, func) super(AugmentImageComponents, self).__init__(ds, func)
......
...@@ -3,29 +3,170 @@ ...@@ -3,29 +3,170 @@
import sys import sys
import numpy as np
import cv2 import cv2
import unittest
from . import AugmentorList from .base import ImageAugmentor, AugmentorList
from .crop import * from .imgproc import Contrast
from .deform import *
from .imgproc import *
from .noise import SaltPepperNoise from .noise import SaltPepperNoise
from .noname import * from .misc import Flip, Resize
anchors = [(0.2, 0.2), (0.7, 0.2), (0.8, 0.8), (0.5, 0.5), (0.2, 0.5)]
augmentors = AugmentorList([ def _rand_image(shape=(20, 20)):
Contrast((0.8, 1.2)), return np.random.rand(*shape).astype("float32")
Flip(horiz=True),
GaussianDeform(anchors, (360, 480), 0.2, randrange=20),
# RandomCropRandomShape(0.3), class LegacyBrightness(ImageAugmentor):
SaltPepperNoise() def __init__(self, delta, clip=True):
]) super(LegacyBrightness, self).__init__()
assert delta > 0
img = cv2.imread(sys.argv[1]) self._init(locals())
newimg, prms = augmentors._augment_return_params(img)
cv2.imshow(" ", newimg.astype('uint8')) def _get_augment_params(self, _):
cv2.waitKey() v = self._rand_range(-self.delta, self.delta)
return v
newimg = augmentors._augment(img, prms)
cv2.imshow(" ", newimg.astype('uint8')) def _augment(self, img, v):
cv2.waitKey() old_dtype = img.dtype
img = img.astype('float32')
img += v
if self.clip or old_dtype == np.uint8:
img = np.clip(img, 0, 255)
return img.astype(old_dtype)
class LegacyFlip(ImageAugmentor):
def __init__(self, horiz=False, vert=False, prob=0.5):
super(LegacyFlip, self).__init__()
if horiz and vert:
raise ValueError("Cannot do both horiz and vert. Please use two Flip instead.")
elif horiz:
self.code = 1
elif vert:
self.code = 0
else:
raise ValueError("At least one of horiz or vert has to be True!")
self._init(locals())
def _get_augment_params(self, img):
h, w = img.shape[:2]
do = self._rand_range() < self.prob
return (do, h, w)
def _augment(self, img, param):
do, _, _ = param
if do:
ret = cv2.flip(img, self.code)
if img.ndim == 3 and ret.ndim == 2:
ret = ret[:, :, np.newaxis]
else:
ret = img
return ret
def _augment_coords(self, coords, param):
do, h, w = param
if do:
if self.code == 0:
coords[:, 1] = h - coords[:, 1]
elif self.code == 1:
coords[:, 0] = w - coords[:, 0]
return coords
class ImgAugTest(unittest.TestCase):
def _get_augs(self):
return AugmentorList([
Contrast((0.8, 1.2)),
Flip(horiz=True),
Resize((30, 30)),
SaltPepperNoise()
])
def _get_augs_with_legacy(self):
return AugmentorList([
LegacyBrightness(0.5),
LegacyFlip(horiz=True),
Resize((30, 30)),
SaltPepperNoise()
])
def test_augmentors(self):
augmentors = self._get_augs()
img = _rand_image()
orig = img.copy()
tfms = augmentors.get_transform(img)
# test printing
print(augmentors)
print(tfms)
newimg = tfms.apply_image(img)
print(tfms) # lazy ones will instantiate after the first apply
newimg2 = tfms.apply_image(orig)
self.assertTrue(np.allclose(newimg, newimg2))
self.assertEqual(newimg2.shape[0], 30)
coords = np.asarray([[0, 0], [10, 12]], dtype="float32")
tfms.apply_coords(coords)
def test_legacy_usage(self):
augmentors = self._get_augs()
img = _rand_image()
orig = img.copy()
newimg, tfms = augmentors.augment_return_params(img)
newimg2 = augmentors.augment_with_params(orig, tfms)
self.assertTrue(np.allclose(newimg, newimg2))
self.assertEqual(newimg2.shape[0], 30)
coords = np.asarray([[0, 0], [10, 12]], dtype="float32")
augmentors.augment_coords(coords, tfms)
def test_legacy_augs_new_usage(self):
augmentors = self._get_augs_with_legacy()
img = _rand_image()
orig = img.copy()
tfms = augmentors.get_transform(img)
newimg = tfms.apply_image(img)
newimg2 = tfms.apply_image(orig)
self.assertTrue(np.allclose(newimg, newimg2))
self.assertEqual(newimg2.shape[0], 30)
coords = np.asarray([[0, 0], [10, 12]], dtype="float32")
tfms.apply_coords(coords)
def test_legacy_augs_legacy_usage(self):
augmentors = self._get_augs_with_legacy()
img = _rand_image()
orig = img.copy()
newimg, tfms = augmentors.augment_return_params(img)
newimg2 = augmentors.augment_with_params(orig, tfms)
self.assertTrue(np.allclose(newimg, newimg2))
self.assertEqual(newimg2.shape[0], 30)
coords = np.asarray([[0, 0], [10, 12]], dtype="float32")
augmentors.augment_coords(coords, tfms)
if __name__ == '__main__':
anchors = [(0.2, 0.2), (0.7, 0.2), (0.8, 0.8), (0.5, 0.5), (0.2, 0.5)]
augmentors = AugmentorList([
Contrast((0.8, 1.2)),
Flip(horiz=True),
# RandomCropRandomShape(0.3),
SaltPepperNoise()
])
img = cv2.imread(sys.argv[1])
newimg, prms = augmentors._augment_return_params(img)
cv2.imshow(" ", newimg.astype('uint8'))
cv2.waitKey()
newimg = augmentors._augment(img, prms)
cv2.imshow(" ", newimg.astype('uint8'))
cv2.waitKey()
This diff is collapsed.
...@@ -4,13 +4,12 @@ ...@@ -4,13 +4,12 @@
import numpy as np import numpy as np
import cv2 import cv2
from .base import ImageAugmentor from .base import PhotometricAugmentor
from .meta import MapImage
__all__ = ['ColorSpace', 'Grayscale', 'ToUint8', 'ToFloat32'] __all__ = ['ColorSpace', 'Grayscale', 'ToUint8', 'ToFloat32']
class ColorSpace(ImageAugmentor): class ColorSpace(PhotometricAugmentor):
""" Convert into another color space. """ """ Convert into another color space. """
def __init__(self, mode, keepdims=True): def __init__(self, mode, keepdims=True):
...@@ -20,6 +19,7 @@ class ColorSpace(ImageAugmentor): ...@@ -20,6 +19,7 @@ class ColorSpace(ImageAugmentor):
keepdims (bool): keep the dimension of image unchanged if OpenCV keepdims (bool): keep the dimension of image unchanged if OpenCV
changes it. changes it.
""" """
super(ColorSpace, self).__init__()
self._init(locals()) self._init(locals())
def _augment(self, img, _): def _augment(self, img, _):
...@@ -43,13 +43,13 @@ class Grayscale(ColorSpace): ...@@ -43,13 +43,13 @@ class Grayscale(ColorSpace):
super(Grayscale, self).__init__(mode, keepdims) super(Grayscale, self).__init__(mode, keepdims)
class ToUint8(MapImage): class ToUint8(PhotometricAugmentor):
""" Convert image to uint8. Useful to reduce communication overhead. """ """ Convert image to uint8. Useful to reduce communication overhead. """
def __init__(self): def _augment(self, img, _):
super(ToUint8, self).__init__(lambda x: np.clip(x, 0, 255).astype(np.uint8), lambda x: x) return np.clip(img, 0, 255).astype(np.uint8)
class ToFloat32(MapImage): class ToFloat32(PhotometricAugmentor):
""" Convert image to float32, may increase quality of the augmentor. """ """ Convert image to float32, may increase quality of the augmentor. """
def __init__(self): def _augment(self, img, _):
super(ToFloat32, self).__init__(lambda x: x.astype(np.float32), lambda x: x) return img.astype(np.float32)
...@@ -6,14 +6,14 @@ import cv2 ...@@ -6,14 +6,14 @@ import cv2
from ...utils.argtools import shape2d from ...utils.argtools import shape2d
from ...utils.develop import log_deprecated from ...utils.develop import log_deprecated
from .base import ImageAugmentor from .base import ImageAugmentor, ImagePlaceholder
from .transform import CropTransform, TransformAugmentorBase from .transform import CropTransform, TransformList, ResizeTransform
from .misc import ResizeShortestEdge from .misc import ResizeShortestEdge
__all__ = ['RandomCrop', 'CenterCrop', 'RandomCropRandomShape', 'GoogleNetRandomCropAndResize'] __all__ = ['RandomCrop', 'CenterCrop', 'RandomCropRandomShape', 'GoogleNetRandomCropAndResize']
class RandomCrop(TransformAugmentorBase): class RandomCrop(ImageAugmentor):
""" Randomly crop the image into a smaller one """ """ Randomly crop the image into a smaller one """
def __init__(self, crop_shape): def __init__(self, crop_shape):
...@@ -25,7 +25,7 @@ class RandomCrop(TransformAugmentorBase): ...@@ -25,7 +25,7 @@ class RandomCrop(TransformAugmentorBase):
super(RandomCrop, self).__init__() super(RandomCrop, self).__init__()
self._init(locals()) self._init(locals())
def _get_augment_params(self, img): def get_transform(self, img):
orig_shape = img.shape orig_shape = img.shape
assert orig_shape[0] >= self.crop_shape[0] \ assert orig_shape[0] >= self.crop_shape[0] \
and orig_shape[1] >= self.crop_shape[1], orig_shape and orig_shape[1] >= self.crop_shape[1], orig_shape
...@@ -36,7 +36,7 @@ class RandomCrop(TransformAugmentorBase): ...@@ -36,7 +36,7 @@ class RandomCrop(TransformAugmentorBase):
return CropTransform(h0, w0, self.crop_shape[0], self.crop_shape[1]) return CropTransform(h0, w0, self.crop_shape[0], self.crop_shape[1])
class CenterCrop(TransformAugmentorBase): class CenterCrop(ImageAugmentor):
""" Crop the image at the center""" """ Crop the image at the center"""
def __init__(self, crop_shape): def __init__(self, crop_shape):
...@@ -47,7 +47,7 @@ class CenterCrop(TransformAugmentorBase): ...@@ -47,7 +47,7 @@ class CenterCrop(TransformAugmentorBase):
crop_shape = shape2d(crop_shape) crop_shape = shape2d(crop_shape)
self._init(locals()) self._init(locals())
def _get_augment_params(self, img): def get_transform(self, img):
orig_shape = img.shape orig_shape = img.shape
assert orig_shape[0] >= self.crop_shape[0] \ assert orig_shape[0] >= self.crop_shape[0] \
and orig_shape[1] >= self.crop_shape[1], orig_shape and orig_shape[1] >= self.crop_shape[1], orig_shape
...@@ -56,7 +56,7 @@ class CenterCrop(TransformAugmentorBase): ...@@ -56,7 +56,7 @@ class CenterCrop(TransformAugmentorBase):
return CropTransform(h0, w0, self.crop_shape[0], self.crop_shape[1]) return CropTransform(h0, w0, self.crop_shape[0], self.crop_shape[1])
class RandomCropRandomShape(TransformAugmentorBase): class RandomCropRandomShape(ImageAugmentor):
""" Random crop with a random shape""" """ Random crop with a random shape"""
def __init__(self, wmin, hmin, def __init__(self, wmin, hmin,
...@@ -70,18 +70,19 @@ class RandomCropRandomShape(TransformAugmentorBase): ...@@ -70,18 +70,19 @@ class RandomCropRandomShape(TransformAugmentorBase):
wmin, hmin, wmax, hmax: range to sample shape. wmin, hmin, wmax, hmax: range to sample shape.
max_aspect_ratio (float): this argument has no effect and is deprecated. max_aspect_ratio (float): this argument has no effect and is deprecated.
""" """
super(RandomCropRandomShape, self).__init__()
if max_aspect_ratio is not None: if max_aspect_ratio is not None:
log_deprecated("RandomCropRandomShape(max_aspect_ratio)", "It is never implemented!", "2020-06-06") log_deprecated("RandomCropRandomShape(max_aspect_ratio)", "It is never implemented!", "2020-06-06")
self._init(locals()) self._init(locals())
def _get_augment_params(self, img): def get_transform(self, img):
hmax = self.hmax or img.shape[0] hmax = self.hmax or img.shape[0]
wmax = self.wmax or img.shape[1] wmax = self.wmax or img.shape[1]
h = self.rng.randint(self.hmin, hmax + 1) h = self.rng.randint(self.hmin, hmax + 1)
w = self.rng.randint(self.wmin, wmax + 1) w = self.rng.randint(self.wmin, wmax + 1)
diffh = img.shape[0] - h diffh = img.shape[0] - h
diffw = img.shape[1] - w diffw = img.shape[1] - w
assert diffh >= 0 and diffw >= 0 assert diffh >= 0 and diffw >= 0, str(diffh) + ", " + str(diffw)
y0 = 0 if diffh == 0 else self.rng.randint(diffh) y0 = 0 if diffh == 0 else self.rng.randint(diffh)
x0 = 0 if diffw == 0 else self.rng.randint(diffw) x0 = 0 if diffw == 0 else self.rng.randint(diffw)
return CropTransform(y0, x0, h, w) return CropTransform(y0, x0, h, w)
...@@ -106,9 +107,10 @@ class GoogleNetRandomCropAndResize(ImageAugmentor): ...@@ -106,9 +107,10 @@ class GoogleNetRandomCropAndResize(ImageAugmentor):
aspect_ratio_range (tuple(float)): Defaults to make aspect ratio in 3/4-4/3. aspect_ratio_range (tuple(float)): Defaults to make aspect ratio in 3/4-4/3.
target_shape (int): Defaults to 224, the standard ImageNet image shape. target_shape (int): Defaults to 224, the standard ImageNet image shape.
""" """
super(GoogleNetRandomCropAndResize, self).__init__()
self._init(locals()) self._init(locals())
def _augment(self, img, _): def get_transform(self, img):
h, w = img.shape[:2] h, w = img.shape[:2]
area = h * w area = h * w
for _ in range(10): for _ in range(10):
...@@ -121,12 +123,11 @@ class GoogleNetRandomCropAndResize(ImageAugmentor): ...@@ -121,12 +123,11 @@ class GoogleNetRandomCropAndResize(ImageAugmentor):
if hh <= h and ww <= w: if hh <= h and ww <= w:
x1 = 0 if w == ww else self.rng.randint(0, w - ww) x1 = 0 if w == ww else self.rng.randint(0, w - ww)
y1 = 0 if h == hh else self.rng.randint(0, h - hh) y1 = 0 if h == hh else self.rng.randint(0, h - hh)
out = img[y1:y1 + hh, x1:x1 + ww] return TransformList([
out = cv2.resize(out, (self.target_shape, self.target_shape), interpolation=self.interp) CropTransform(y1, x1, hh, ww),
return out ResizeTransform(hh, ww, self.target_shape, self.target_shape, interp=self.interp)
out = ResizeShortestEdge(self.target_shape, interp=self.interp).augment(img) ])
out = CenterCrop(self.target_shape).augment(out) tfm1 = ResizeShortestEdge(self.target_shape, interp=self.interp).get_transform(img)
return out out_shape = (tfm1.new_h, tfm1.new_w)
tfm2 = CenterCrop(self.target_shape).get_transform(ImagePlaceholder(shape=out_shape))
def _augment_coords(self, coords, param): return TransformList([tfm1, tfm2])
raise NotImplementedError()
...@@ -6,6 +6,7 @@ import numpy as np ...@@ -6,6 +6,7 @@ import numpy as np
from ...utils import logger from ...utils import logger
from .base import ImageAugmentor from .base import ImageAugmentor
from .transform import TransformFactory
__all__ = [] __all__ = []
...@@ -97,14 +98,11 @@ class GaussianDeform(ImageAugmentor): ...@@ -97,14 +98,11 @@ class GaussianDeform(ImageAugmentor):
self.randrange = randrange self.randrange = randrange
self.sigma = sigma self.sigma = sigma
def _get_augment_params(self, img): def get_transform(self, img):
v = self.rng.rand(self.K, 2).astype('float32') - 0.5 v = self.rng.rand(self.K, 2).astype('float32') - 0.5
v = v * 2 * self.randrange v = v * 2 * self.randrange
return v return TransformFactory(name=str(self), apply_image=lambda img: self._augment(img, v))
def _augment(self, img, v): def _augment(self, img, v):
grid = self.grid + np.dot(self.gws, v) grid = self.grid + np.dot(self.gws, v)
return np_sample(img, grid) return np_sample(img, grid)
def _augment_coords(self, coords, param):
raise NotImplementedError()
...@@ -3,10 +3,26 @@ ...@@ -3,10 +3,26 @@
import numpy as np import numpy as np
from .base import ImageAugmentor from .base import ImageAugmentor
from .transform import Transform
__all__ = ['IAAugmentor', 'Albumentations'] __all__ = ['IAAugmentor', 'Albumentations']
class IAATransform(Transform):
def __init__(self, aug, img_shape):
self._init(locals())
def apply_image(self, img):
return self.aug.augment_image(img)
def apply_coords(self, coords):
import imgaug as IA
points = [IA.Keypoint(x=x, y=y) for x, y in coords]
points = IA.KeypointsOnImage(points, shape=self.img_shape)
augmented = self.aug.augment_keypoints([points])[0].keypoints
return np.asarray([[p.x, p.y] for p in augmented])
class IAAugmentor(ImageAugmentor): class IAAugmentor(ImageAugmentor):
""" """
Wrap an augmentor form the IAA library: https://github.com/aleju/imgaug. Wrap an augmentor form the IAA library: https://github.com/aleju/imgaug.
...@@ -43,20 +59,16 @@ class IAAugmentor(ImageAugmentor): ...@@ -43,20 +59,16 @@ class IAAugmentor(ImageAugmentor):
super(IAAugmentor, self).__init__() super(IAAugmentor, self).__init__()
self._aug = augmentor self._aug = augmentor
def _get_augment_params(self, img): def get_transform(self, img):
return (self._aug.to_deterministic(), img.shape) return IAATransform(self._aug.to_deterministic(), img.shape)
def _augment(self, img, param):
aug, _ = param
return aug.augment_image(img)
def _augment_coords(self, coords, param): class AlbumentationsTransform(Transform):
import imgaug as IA def __init__(self, aug, param):
aug, shape = param self._init(locals())
points = [IA.Keypoint(x=x, y=y) for x, y in coords]
points = IA.KeypointsOnImage(points, shape=shape) def apply_image(self, img):
augmented = aug.augment_keypoints([points])[0].keypoints return self.aug.apply(img, **self.param)
return np.asarray([[p.x, p.y] for p in augmented])
class Albumentations(ImageAugmentor): class Albumentations(ImageAugmentor):
...@@ -81,11 +93,5 @@ class Albumentations(ImageAugmentor): ...@@ -81,11 +93,5 @@ class Albumentations(ImageAugmentor):
super(Albumentations, self).__init__() super(Albumentations, self).__init__()
self._aug = augmentor self._aug = augmentor
def _get_augment_params(self, img): def get_transform(self, img):
return self._aug.get_params() return AlbumentationsTransform(self._aug, self._aug.get_params())
def _augment(self, img, param):
return self._aug.apply(img, **param)
def _augment_coords(self, coords, param):
raise NotImplementedError()
...@@ -7,12 +7,12 @@ import numpy as np ...@@ -7,12 +7,12 @@ import numpy as np
import cv2 import cv2
from .base import ImageAugmentor from .base import ImageAugmentor
from .transform import TransformAugmentorBase, WarpAffineTransform from .transform import WarpAffineTransform, CropTransform, TransformList
__all__ = ['Shift', 'Rotation', 'RotationAndCropValid', 'Affine'] __all__ = ['Shift', 'Rotation', 'RotationAndCropValid', 'Affine']
class Shift(TransformAugmentorBase): class Shift(ImageAugmentor):
""" Random horizontal and vertical shifts """ """ Random horizontal and vertical shifts """
def __init__(self, horiz_frac=0, vert_frac=0, def __init__(self, horiz_frac=0, vert_frac=0,
...@@ -28,7 +28,7 @@ class Shift(TransformAugmentorBase): ...@@ -28,7 +28,7 @@ class Shift(TransformAugmentorBase):
super(Shift, self).__init__() super(Shift, self).__init__()
self._init(locals()) self._init(locals())
def _get_augment_params(self, img): def get_transform(self, img):
max_dx = self.horiz_frac * img.shape[1] max_dx = self.horiz_frac * img.shape[1]
max_dy = self.vert_frac * img.shape[0] max_dy = self.vert_frac * img.shape[0]
dx = np.round(self._rand_range(-max_dx, max_dx)) dx = np.round(self._rand_range(-max_dx, max_dx))
...@@ -40,7 +40,7 @@ class Shift(TransformAugmentorBase): ...@@ -40,7 +40,7 @@ class Shift(TransformAugmentorBase):
borderMode=self.border, borderValue=self.border_value) borderMode=self.border, borderValue=self.border_value)
class Rotation(TransformAugmentorBase): class Rotation(ImageAugmentor):
""" Random rotate the image w.r.t a random center""" """ Random rotate the image w.r.t a random center"""
def __init__(self, max_deg, center_range=(0, 1), def __init__(self, max_deg, center_range=(0, 1),
...@@ -61,7 +61,7 @@ class Rotation(TransformAugmentorBase): ...@@ -61,7 +61,7 @@ class Rotation(TransformAugmentorBase):
super(Rotation, self).__init__() super(Rotation, self).__init__()
self._init(locals()) self._init(locals())
def _get_augment_params(self, img): def get_transform(self, img):
center = img.shape[1::-1] * self._rand_range( center = img.shape[1::-1] * self._rand_range(
self.center_range[0], self.center_range[1], (2,)) self.center_range[0], self.center_range[1], (2,))
deg = self._rand_range(-self.max_deg, self.max_deg) deg = self._rand_range(-self.max_deg, self.max_deg)
...@@ -100,29 +100,27 @@ class RotationAndCropValid(ImageAugmentor): ...@@ -100,29 +100,27 @@ class RotationAndCropValid(ImageAugmentor):
super(RotationAndCropValid, self).__init__() super(RotationAndCropValid, self).__init__()
self._init(locals()) self._init(locals())
def _get_augment_params(self, img): def _get_deg(self, img):
deg = self._rand_range(-self.max_deg, self.max_deg) deg = self._rand_range(-self.max_deg, self.max_deg)
if self.step_deg: if self.step_deg:
deg = deg // self.step_deg * self.step_deg deg = deg // self.step_deg * self.step_deg
return deg return deg
def _augment(self, img, deg): def get_transform(self, img):
deg = self._get_deg(img)
h, w = img.shape[:2]
center = (img.shape[1] * 0.5, img.shape[0] * 0.5) center = (img.shape[1] * 0.5, img.shape[0] * 0.5)
rot_m = cv2.getRotationMatrix2D((center[0] - 0.5, center[1] - 0.5), deg, 1) rot_m = cv2.getRotationMatrix2D((center[0] - 0.5, center[1] - 0.5), deg, 1)
ret = cv2.warpAffine(img, rot_m, img.shape[1::-1], tfm = WarpAffineTransform(rot_m, (w, h), interp=self.interp)
flags=self.interp, borderMode=cv2.BORDER_CONSTANT)
if img.ndim == 3 and ret.ndim == 2: neww, newh = RotationAndCropValid.largest_rotated_rect(w, h, deg)
ret = ret[:, :, np.newaxis] neww = min(neww, w)
neww, newh = RotationAndCropValid.largest_rotated_rect(ret.shape[1], ret.shape[0], deg) newh = min(newh, h)
neww = min(neww, ret.shape[1])
newh = min(newh, ret.shape[0])
newx = int(center[0] - neww * 0.5) newx = int(center[0] - neww * 0.5)
newy = int(center[1] - newh * 0.5) newy = int(center[1] - newh * 0.5)
# print(ret.shape, deg, newx, newy, neww, newh) tfm2 = CropTransform(newy, newx, newh, neww)
return ret[newy:newy + newh, newx:newx + neww] return TransformList([tfm, tfm2])
def _augment_coords(self, coords, param):
raise NotImplementedError()
@staticmethod @staticmethod
def largest_rotated_rect(w, h, angle): def largest_rotated_rect(w, h, angle):
...@@ -152,7 +150,7 @@ class RotationAndCropValid(ImageAugmentor): ...@@ -152,7 +150,7 @@ class RotationAndCropValid(ImageAugmentor):
return int(np.round(wr)), int(np.round(hr)) return int(np.round(wr)), int(np.round(hr))
class Affine(TransformAugmentorBase): class Affine(ImageAugmentor):
""" """
Random affine transform of the image w.r.t to the image center. Random affine transform of the image w.r.t to the image center.
Transformations involve: Transformations involve:
...@@ -193,7 +191,7 @@ class Affine(TransformAugmentorBase): ...@@ -193,7 +191,7 @@ class Affine(TransformAugmentorBase):
super(Affine, self).__init__() super(Affine, self).__init__()
self._init(locals()) self._init(locals())
def _get_augment_params(self, img): def get_transform(self, img):
if self.scale is not None: if self.scale is not None:
scale = self._rand_range(self.scale[0], self.scale[1]) scale = self._rand_range(self.scale[0], self.scale[1])
else: else:
......
...@@ -5,13 +5,13 @@ ...@@ -5,13 +5,13 @@
import numpy as np import numpy as np
import cv2 import cv2
from .base import ImageAugmentor from .base import PhotometricAugmentor
__all__ = ['Hue', 'Brightness', 'BrightnessScale', 'Contrast', 'MeanVarianceNormalize', __all__ = ['Hue', 'Brightness', 'BrightnessScale', 'Contrast', 'MeanVarianceNormalize',
'GaussianBlur', 'Gamma', 'Clip', 'Saturation', 'Lighting', 'MinMaxNormalize'] 'GaussianBlur', 'Gamma', 'Clip', 'Saturation', 'Lighting', 'MinMaxNormalize']
class Hue(ImageAugmentor): class Hue(PhotometricAugmentor):
""" Randomly change color hue. """ Randomly change color hue.
""" """
...@@ -43,7 +43,7 @@ class Hue(ImageAugmentor): ...@@ -43,7 +43,7 @@ class Hue(ImageAugmentor):
return img return img
class Brightness(ImageAugmentor): class Brightness(PhotometricAugmentor):
""" """
Adjust brightness by adding a random number. Adjust brightness by adding a random number.
""" """
...@@ -51,15 +51,14 @@ class Brightness(ImageAugmentor): ...@@ -51,15 +51,14 @@ class Brightness(ImageAugmentor):
""" """
Args: Args:
delta (float): Randomly add a value within [-delta,delta] delta (float): Randomly add a value within [-delta,delta]
clip (bool): clip results to [0,255] if data type is uint8. clip (bool): clip results to [0,255] even when data type is not uint8.
""" """
super(Brightness, self).__init__() super(Brightness, self).__init__()
assert delta > 0 assert delta > 0
self._init(locals()) self._init(locals())
def _get_augment_params(self, _): def _get_augment_params(self, _):
v = self._rand_range(-self.delta, self.delta) return self._rand_range(-self.delta, self.delta)
return v
def _augment(self, img, v): def _augment(self, img, v):
old_dtype = img.dtype old_dtype = img.dtype
...@@ -70,7 +69,7 @@ class Brightness(ImageAugmentor): ...@@ -70,7 +69,7 @@ class Brightness(ImageAugmentor):
return img.astype(old_dtype) return img.astype(old_dtype)
class BrightnessScale(ImageAugmentor): class BrightnessScale(PhotometricAugmentor):
""" """
Adjust brightness by scaling by a random factor. Adjust brightness by scaling by a random factor.
""" """
...@@ -78,14 +77,13 @@ class BrightnessScale(ImageAugmentor): ...@@ -78,14 +77,13 @@ class BrightnessScale(ImageAugmentor):
""" """
Args: Args:
range (tuple): Randomly scale the image by a factor in (range[0], range[1]) range (tuple): Randomly scale the image by a factor in (range[0], range[1])
clip (bool): clip results to [0,255] if data type is uint8. clip (bool): clip results to [0,255] even when data type is not uint8.
""" """
super(BrightnessScale, self).__init__() super(BrightnessScale, self).__init__()
self._init(locals()) self._init(locals())
def _get_augment_params(self, _): def _get_augment_params(self, _):
v = self._rand_range(*self.range) return self._rand_range(*self.range)
return v
def _augment(self, img, v): def _augment(self, img, v):
old_dtype = img.dtype old_dtype = img.dtype
...@@ -96,7 +94,7 @@ class BrightnessScale(ImageAugmentor): ...@@ -96,7 +94,7 @@ class BrightnessScale(ImageAugmentor):
return img.astype(old_dtype) return img.astype(old_dtype)
class Contrast(ImageAugmentor): class Contrast(PhotometricAugmentor):
""" """
Apply ``x = (x - mean) * contrast_factor + mean`` to each channel. Apply ``x = (x - mean) * contrast_factor + mean`` to each channel.
""" """
...@@ -106,12 +104,12 @@ class Contrast(ImageAugmentor): ...@@ -106,12 +104,12 @@ class Contrast(ImageAugmentor):
Args: Args:
factor_range (list or tuple): an interval to randomly sample the `contrast_factor`. factor_range (list or tuple): an interval to randomly sample the `contrast_factor`.
rgb (bool or None): if None, use the mean per-channel. rgb (bool or None): if None, use the mean per-channel.
clip (bool): clip to [0, 255] if data type is uint8. clip (bool): clip to [0, 255] even when data type is not uint8.
""" """
super(Contrast, self).__init__() super(Contrast, self).__init__()
self._init(locals()) self._init(locals())
def _get_augment_params(self, img): def _get_augment_params(self, _):
return self._rand_range(*self.factor_range) return self._rand_range(*self.factor_range)
def _augment(self, img, r): def _augment(self, img, r):
...@@ -133,7 +131,7 @@ class Contrast(ImageAugmentor): ...@@ -133,7 +131,7 @@ class Contrast(ImageAugmentor):
return img.astype(old_dtype) return img.astype(old_dtype)
class MeanVarianceNormalize(ImageAugmentor): class MeanVarianceNormalize(PhotometricAugmentor):
""" """
Linearly scales the image to have zero mean and unit norm. Linearly scales the image to have zero mean and unit norm.
``x = (x - mean) / adjusted_stddev`` ``x = (x - mean) / adjusted_stddev``
...@@ -162,7 +160,7 @@ class MeanVarianceNormalize(ImageAugmentor): ...@@ -162,7 +160,7 @@ class MeanVarianceNormalize(ImageAugmentor):
return img return img
class GaussianBlur(ImageAugmentor): class GaussianBlur(PhotometricAugmentor):
""" Gaussian blur the image with random window size""" """ Gaussian blur the image with random window size"""
def __init__(self, max_size=3): def __init__(self, max_size=3):
...@@ -173,7 +171,7 @@ class GaussianBlur(ImageAugmentor): ...@@ -173,7 +171,7 @@ class GaussianBlur(ImageAugmentor):
super(GaussianBlur, self).__init__() super(GaussianBlur, self).__init__()
self._init(locals()) self._init(locals())
def _get_augment_params(self, img): def _get_augment_params(self, _):
sx, sy = self.rng.randint(self.max_size, size=(2,)) sx, sy = self.rng.randint(self.max_size, size=(2,))
sx = sx * 2 + 1 sx = sx * 2 + 1
sy = sy * 2 + 1 sy = sy * 2 + 1
...@@ -184,7 +182,7 @@ class GaussianBlur(ImageAugmentor): ...@@ -184,7 +182,7 @@ class GaussianBlur(ImageAugmentor):
borderType=cv2.BORDER_REPLICATE), img.shape) borderType=cv2.BORDER_REPLICATE), img.shape)
class Gamma(ImageAugmentor): class Gamma(PhotometricAugmentor):
""" Randomly adjust gamma """ """ Randomly adjust gamma """
def __init__(self, range=(-0.5, 0.5)): def __init__(self, range=(-0.5, 0.5)):
""" """
...@@ -207,7 +205,7 @@ class Gamma(ImageAugmentor): ...@@ -207,7 +205,7 @@ class Gamma(ImageAugmentor):
return ret return ret
class Clip(ImageAugmentor): class Clip(PhotometricAugmentor):
""" Clip the pixel values """ """ Clip the pixel values """
def __init__(self, min=0, max=255): def __init__(self, min=0, max=255):
...@@ -218,11 +216,10 @@ class Clip(ImageAugmentor): ...@@ -218,11 +216,10 @@ class Clip(ImageAugmentor):
self._init(locals()) self._init(locals())
def _augment(self, img, _): def _augment(self, img, _):
img = np.clip(img, self.min, self.max) return np.clip(img, self.min, self.max)
return img
class Saturation(ImageAugmentor): class Saturation(PhotometricAugmentor):
""" Randomly adjust saturation. """ Randomly adjust saturation.
Follows the implementation in `fb.resnet.torch Follows the implementation in `fb.resnet.torch
<https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L218>`__. <https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L218>`__.
...@@ -252,7 +249,7 @@ class Saturation(ImageAugmentor): ...@@ -252,7 +249,7 @@ class Saturation(ImageAugmentor):
return ret.astype(old_dtype) return ret.astype(old_dtype)
class Lighting(ImageAugmentor): class Lighting(PhotometricAugmentor):
""" Lighting noise, as in the paper """ Lighting noise, as in the paper
`ImageNet Classification with Deep Convolutional Neural Networks `ImageNet Classification with Deep Convolutional Neural Networks
<https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf>`_. <https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf>`_.
...@@ -267,6 +264,7 @@ class Lighting(ImageAugmentor): ...@@ -267,6 +264,7 @@ class Lighting(ImageAugmentor):
eigval: a vector of (3,). The eigenvalues of 3 channels. eigval: a vector of (3,). The eigenvalues of 3 channels.
eigvec: a 3x3 matrix. Each column is one eigen vector. eigvec: a 3x3 matrix. Each column is one eigen vector.
""" """
super(Lighting, self).__init__()
eigval = np.asarray(eigval) eigval = np.asarray(eigval)
eigvec = np.asarray(eigvec) eigvec = np.asarray(eigvec)
assert eigval.shape == (3,) assert eigval.shape == (3,)
...@@ -275,8 +273,7 @@ class Lighting(ImageAugmentor): ...@@ -275,8 +273,7 @@ class Lighting(ImageAugmentor):
def _get_augment_params(self, img): def _get_augment_params(self, img):
assert img.shape[2] == 3 assert img.shape[2] == 3
ret = self.rng.randn(3) * self.std return (self.rng.randn(3) * self.std).astype("float32")
return ret.astype('float32')
def _augment(self, img, v): def _augment(self, img, v):
old_dtype = img.dtype old_dtype = img.dtype
...@@ -289,7 +286,7 @@ class Lighting(ImageAugmentor): ...@@ -289,7 +286,7 @@ class Lighting(ImageAugmentor):
return img.astype(old_dtype) return img.astype(old_dtype)
class MinMaxNormalize(ImageAugmentor): class MinMaxNormalize(PhotometricAugmentor):
""" """
Linearly scales the image to the range [min, max]. Linearly scales the image to the range [min, max].
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
from .base import ImageAugmentor from .base import ImageAugmentor
from .transform import NoOpTransform, TransformList, TransformFactory
__all__ = ['RandomChooseAug', 'MapImage', 'Identity', 'RandomApplyAug', __all__ = ['RandomChooseAug', 'MapImage', 'Identity', 'RandomApplyAug',
'RandomOrderAug'] 'RandomOrderAug']
...@@ -10,8 +11,8 @@ __all__ = ['RandomChooseAug', 'MapImage', 'Identity', 'RandomApplyAug', ...@@ -10,8 +11,8 @@ __all__ = ['RandomChooseAug', 'MapImage', 'Identity', 'RandomApplyAug',
class Identity(ImageAugmentor): class Identity(ImageAugmentor):
""" A no-op augmentor """ """ A no-op augmentor """
def _augment(self, img, _): def get_transform(self, img):
return img return NoOpTransform()
class RandomApplyAug(ImageAugmentor): class RandomApplyAug(ImageAugmentor):
...@@ -22,44 +23,23 @@ class RandomApplyAug(ImageAugmentor): ...@@ -22,44 +23,23 @@ class RandomApplyAug(ImageAugmentor):
def __init__(self, aug, prob): def __init__(self, aug, prob):
""" """
Args: Args:
aug (ImageAugmentor): an augmentor aug (ImageAugmentor): an augmentor.
prob (float): the probability prob (float): the probability to apply the augmentor.
""" """
self._init(locals()) self._init(locals())
super(RandomApplyAug, self).__init__() super(RandomApplyAug, self).__init__()
def _get_augment_params(self, img): def get_transform(self, img):
p = self.rng.rand() p = self.rng.rand()
if p < self.prob: if p < self.prob:
prm = self.aug._get_augment_params(img) return self.aug.get_transform(img)
return (True, prm)
else: else:
return (False, None) return NoOpTransform()
def _augment_return_params(self, img):
p = self.rng.rand()
if p < self.prob:
img, prms = self.aug._augment_return_params(img)
return img, (True, prms)
else:
return img, (False, None)
def reset_state(self): def reset_state(self):
super(RandomApplyAug, self).reset_state() super(RandomApplyAug, self).reset_state()
self.aug.reset_state() self.aug.reset_state()
def _augment(self, img, prm):
if not prm[0]:
return img
else:
return self.aug._augment(img, prm[1])
def _augment_coords(self, coords, prm):
if not prm[0]:
return coords
else:
return self.aug._augment_coords(coords, prm[1])
class RandomChooseAug(ImageAugmentor): class RandomChooseAug(ImageAugmentor):
""" Randomly choose one from a list of augmentors """ """ Randomly choose one from a list of augmentors """
...@@ -82,18 +62,9 @@ class RandomChooseAug(ImageAugmentor): ...@@ -82,18 +62,9 @@ class RandomChooseAug(ImageAugmentor):
for a in self.aug_lists: for a in self.aug_lists:
a.reset_state() a.reset_state()
def _get_augment_params(self, img): def get_transform(self, img):
aug_idx = self.rng.choice(len(self.aug_lists), p=self.prob) aug_idx = self.rng.choice(len(self.aug_lists), p=self.prob)
aug_prm = self.aug_lists[aug_idx]._get_augment_params(img) return self.aug_lists[aug_idx].get_transform(img)
return aug_idx, aug_prm
def _augment(self, img, prm):
idx, prm = prm
return self.aug_lists[idx]._augment(img, prm)
def _augment_coords(self, coords, prm):
idx, prm = prm
return self.aug_lists[idx]._augment_coords(coords, prm)
class RandomOrderAug(ImageAugmentor): class RandomOrderAug(ImageAugmentor):
...@@ -115,30 +86,19 @@ class RandomOrderAug(ImageAugmentor): ...@@ -115,30 +86,19 @@ class RandomOrderAug(ImageAugmentor):
for a in self.aug_lists: for a in self.aug_lists:
a.reset_state() a.reset_state()
def _get_augment_params(self, img): def get_transform(self, img):
# Note: If augmentors change the shape of image, get_augment_param might not work # Note: this makes assumption that the augmentors do not make changes
# All augmentors should only rely on the shape of image # to the image that will affect how the transforms will be instantiated
# in the subsequent augmentors.
idxs = self.rng.permutation(len(self.aug_lists)) idxs = self.rng.permutation(len(self.aug_lists))
prms = [self.aug_lists[k]._get_augment_params(img) tfms = [self.aug_lists[k].get_transform(img)
for k in range(len(self.aug_lists))] for k in range(len(self.aug_lists))]
return idxs, prms return TransformList([tfms[k] for k in idxs])
def _augment(self, img, prm):
idxs, prms = prm
for k in idxs:
img = self.aug_lists[k]._augment(img, prms[k])
return img
def _augment_coords(self, coords, prm):
idxs, prms = prm
for k in idxs:
img = self.aug_lists[k]._augment_coords(coords, prms[k])
return img
class MapImage(ImageAugmentor): class MapImage(ImageAugmentor):
""" """
Map the image array by a function. Map the image array by simple functions.
""" """
def __init__(self, func, coord_func=None): def __init__(self, func, coord_func=None):
...@@ -146,16 +106,14 @@ class MapImage(ImageAugmentor): ...@@ -146,16 +106,14 @@ class MapImage(ImageAugmentor):
Args: Args:
func: a function which takes an image array and return an augmented one func: a function which takes an image array and return an augmented one
coord_func: optional. A function which takes coordinates and return augmented ones. coord_func: optional. A function which takes coordinates and return augmented ones.
Coordinates have the same format as :func:`ImageAugmentor.augment_coords`. Coordinates should be Nx2 array of (x, y)s.
""" """
super(MapImage, self).__init__() super(MapImage, self).__init__()
self.func = func self.func = func
self.coord_func = coord_func self.coord_func = coord_func
def _augment(self, img, _): def get_transform(self, img):
return self.func(img) if self.coord_func:
return TransformFactory(name="MapImage", apply_image=self.func, apply_coords=self.coord_func)
def _augment_coords(self, coords, _): else:
if self.coord_func is None: return TransformFactory(name="MapImage", apply_image=self.func)
raise NotImplementedError
return self.coord_func(coords)
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: misc.py # File: misc.py
import numpy as np
import cv2 import cv2
from ...utils import logger from ...utils import logger
from ...utils.argtools import shape2d from ...utils.argtools import shape2d
from .base import ImageAugmentor from .base import ImageAugmentor
from .transform import ResizeTransform, TransformAugmentorBase from .transform import ResizeTransform, NoOpTransform, FlipTransform, TransposeTransform
__all__ = ['Flip', 'Resize', 'RandomResize', 'ResizeShortestEdge', 'Transpose'] __all__ = ['Flip', 'Resize', 'RandomResize', 'ResizeShortestEdge', 'Transpose']
...@@ -27,40 +25,20 @@ class Flip(ImageAugmentor): ...@@ -27,40 +25,20 @@ class Flip(ImageAugmentor):
super(Flip, self).__init__() super(Flip, self).__init__()
if horiz and vert: if horiz and vert:
raise ValueError("Cannot do both horiz and vert. Please use two Flip instead.") raise ValueError("Cannot do both horiz and vert. Please use two Flip instead.")
elif horiz: if not horiz and not vert:
self.code = 1
elif vert:
self.code = 0
else:
raise ValueError("At least one of horiz or vert has to be True!") raise ValueError("At least one of horiz or vert has to be True!")
self._init(locals()) self._init(locals())
def _get_augment_params(self, img): def get_transform(self, img):
h, w = img.shape[:2] h, w = img.shape[:2]
do = self._rand_range() < self.prob do = self._rand_range() < self.prob
return (do, h, w) if not do:
return NoOpTransform()
def _augment(self, img, param):
do, _, _ = param
if do:
ret = cv2.flip(img, self.code)
if img.ndim == 3 and ret.ndim == 2:
ret = ret[:, :, np.newaxis]
else: else:
ret = img return FlipTransform(h, w, self.horiz)
return ret
def _augment_coords(self, coords, param):
do, h, w = param
if do:
if self.code == 0:
coords[:, 1] = h - coords[:, 1]
elif self.code == 1:
coords[:, 0] = w - coords[:, 0]
return coords
class Resize(TransformAugmentorBase): class Resize(ImageAugmentor):
""" Resize image to a target size""" """ Resize image to a target size"""
def __init__(self, shape, interp=cv2.INTER_LINEAR): def __init__(self, shape, interp=cv2.INTER_LINEAR):
...@@ -72,13 +50,13 @@ class Resize(TransformAugmentorBase): ...@@ -72,13 +50,13 @@ class Resize(TransformAugmentorBase):
shape = tuple(shape2d(shape)) shape = tuple(shape2d(shape))
self._init(locals()) self._init(locals())
def _get_augment_params(self, img): def get_transform(self, img):
return ResizeTransform( return ResizeTransform(
img.shape[0], img.shape[1], img.shape[0], img.shape[1],
self.shape[0], self.shape[1], self.interp) self.shape[0], self.shape[1], self.interp)
class ResizeShortestEdge(TransformAugmentorBase): class ResizeShortestEdge(ImageAugmentor):
""" """
Resize the shortest edge to a certain number while Resize the shortest edge to a certain number while
keeping the aspect ratio. keeping the aspect ratio.
...@@ -92,18 +70,17 @@ class ResizeShortestEdge(TransformAugmentorBase): ...@@ -92,18 +70,17 @@ class ResizeShortestEdge(TransformAugmentorBase):
size = int(size) size = int(size)
self._init(locals()) self._init(locals())
def _get_augment_params(self, img): def get_transform(self, img):
h, w = img.shape[:2] h, w = img.shape[:2]
scale = self.size * 1.0 / min(h, w) scale = self.size * 1.0 / min(h, w)
if h < w: if h < w:
newh, neww = self.size, int(scale * w + 0.5) newh, neww = self.size, int(scale * w + 0.5)
else: else:
newh, neww = int(scale * h + 0.5), self.size newh, neww = int(scale * h + 0.5), self.size
return ResizeTransform( return ResizeTransform(h, w, newh, neww, self.interp)
h, w, newh, neww, self.interp)
class RandomResize(TransformAugmentorBase): class RandomResize(ImageAugmentor):
""" Randomly rescale width and height of the image.""" """ Randomly rescale width and height of the image."""
def __init__(self, xrange, yrange=None, minimum=(0, 0), aspect_ratio_thres=0.15, def __init__(self, xrange, yrange=None, minimum=(0, 0), aspect_ratio_thres=0.15,
...@@ -137,7 +114,7 @@ class RandomResize(TransformAugmentorBase): ...@@ -137,7 +114,7 @@ class RandomResize(TransformAugmentorBase):
if yrange is not None: if yrange is not None:
logger.warn("aspect_ratio_thres==0, yrange is not used!") logger.warn("aspect_ratio_thres==0, yrange is not used!")
def _get_augment_params(self, img): def get_transform(self, img):
cnt = 0 cnt = 0
h, w = img.shape[:2] h, w = img.shape[:2]
...@@ -186,20 +163,9 @@ class Transpose(ImageAugmentor): ...@@ -186,20 +163,9 @@ class Transpose(ImageAugmentor):
""" """
super(Transpose, self).__init__() super(Transpose, self).__init__()
self.prob = prob self.prob = prob
self._init()
def get_transform(self, _):
def _get_augment_params(self, img): if self.rng.rand() < self.prob:
return self._rand_range() < self.prob return TransposeTransform()
else:
def _augment(self, img, do): return NoOpTransform()
ret = img
if do:
ret = cv2.transpose(img)
if img.ndim == 3 and ret.ndim == 2:
ret = ret[:, :, np.newaxis]
return ret
def _augment_coords(self, coords, do):
if do:
coords = coords[:, ::-1]
return coords
...@@ -5,12 +5,12 @@ ...@@ -5,12 +5,12 @@
import numpy as np import numpy as np
import cv2 import cv2
from .base import ImageAugmentor from .base import PhotometricAugmentor
__all__ = ['JpegNoise', 'GaussianNoise', 'SaltPepperNoise'] __all__ = ['JpegNoise', 'GaussianNoise', 'SaltPepperNoise']
class JpegNoise(ImageAugmentor): class JpegNoise(PhotometricAugmentor):
""" Random JPEG noise. """ """ Random JPEG noise. """
def __init__(self, quality_range=(40, 100)): def __init__(self, quality_range=(40, 100)):
...@@ -29,7 +29,7 @@ class JpegNoise(ImageAugmentor): ...@@ -29,7 +29,7 @@ class JpegNoise(ImageAugmentor):
return cv2.imdecode(enc, 1).astype(img.dtype) return cv2.imdecode(enc, 1).astype(img.dtype)
class GaussianNoise(ImageAugmentor): class GaussianNoise(PhotometricAugmentor):
""" """
Add random Gaussian noise N(0, sigma^2) of the same shape to img. Add random Gaussian noise N(0, sigma^2) of the same shape to img.
""" """
...@@ -53,7 +53,7 @@ class GaussianNoise(ImageAugmentor): ...@@ -53,7 +53,7 @@ class GaussianNoise(ImageAugmentor):
return ret.astype(old_dtype) return ret.astype(old_dtype)
class SaltPepperNoise(ImageAugmentor): class SaltPepperNoise(PhotometricAugmentor):
""" Salt and pepper noise. """ Salt and pepper noise.
Randomly set some elements in image to 0 or 255, regardless of its channels. Randomly set some elements in image to 0 or 255, regardless of its channels.
""" """
......
...@@ -6,6 +6,7 @@ import numpy as np ...@@ -6,6 +6,7 @@ import numpy as np
from abc import abstractmethod from abc import abstractmethod
from .base import ImageAugmentor from .base import ImageAugmentor
from .transform import TransformFactory
__all__ = ['CenterPaste', 'BackgroundFiller', 'ConstantBackgroundFiller', __all__ = ['CenterPaste', 'BackgroundFiller', 'ConstantBackgroundFiller',
'RandomPaste'] 'RandomPaste']
...@@ -51,6 +52,10 @@ class ConstantBackgroundFiller(BackgroundFiller): ...@@ -51,6 +52,10 @@ class ConstantBackgroundFiller(BackgroundFiller):
return np.zeros(return_shape, dtype=img.dtype) + self.value return np.zeros(return_shape, dtype=img.dtype) + self.value
# NOTE:
# apply_coords should be implemeted in paste transform, but not yet done
class CenterPaste(ImageAugmentor): class CenterPaste(ImageAugmentor):
""" """
Paste the image onto the center of a background canvas. Paste the image onto the center of a background canvas.
...@@ -67,7 +72,10 @@ class CenterPaste(ImageAugmentor): ...@@ -67,7 +72,10 @@ class CenterPaste(ImageAugmentor):
self._init(locals()) self._init(locals())
def _augment(self, img, _): def get_transform(self, _):
return TransformFactory(name=str(self), apply_image=lambda img: self._impl(img))
def _impl(self, img):
img_shape = img.shape[:2] img_shape = img.shape[:2]
assert self.background_shape[0] >= img_shape[0] and self.background_shape[1] >= img_shape[1] assert self.background_shape[0] >= img_shape[0] and self.background_shape[1] >= img_shape[1]
...@@ -78,24 +86,22 @@ class CenterPaste(ImageAugmentor): ...@@ -78,24 +86,22 @@ class CenterPaste(ImageAugmentor):
background[y0:y0 + img_shape[0], x0:x0 + img_shape[1]] = img background[y0:y0 + img_shape[0], x0:x0 + img_shape[1]] = img
return background return background
def _augment_coords(self, coords, param):
raise NotImplementedError()
class RandomPaste(CenterPaste): class RandomPaste(CenterPaste):
""" """
Randomly paste the image onto a background canvas. Randomly paste the image onto a background canvas.
""" """
def _get_augment_params(self, img): def get_transform(self, img):
img_shape = img.shape[:2] img_shape = img.shape[:2]
assert self.background_shape[0] > img_shape[0] and self.background_shape[1] > img_shape[1] assert self.background_shape[0] > img_shape[0] and self.background_shape[1] > img_shape[1]
y0 = self._rand_range(self.background_shape[0] - img_shape[0]) y0 = self._rand_range(self.background_shape[0] - img_shape[0])
x0 = self._rand_range(self.background_shape[1] - img_shape[1]) x0 = self._rand_range(self.background_shape[1] - img_shape[1])
return int(x0), int(y0) l = int(x0), int(y0)
return TransformFactory(name=str(self), apply_image=lambda img: self._impl(img, l))
def _augment(self, img, loc): def _impl(self, img, loc):
x0, y0 = loc x0, y0 = loc
img_shape = img.shape[:2] img_shape = img.shape[:2]
background = self.background_filler.fill( background = self.background_filler.fill(
......
This diff is collapsed.
...@@ -10,6 +10,7 @@ import functools ...@@ -10,6 +10,7 @@ import functools
import importlib import importlib
import os import os
import types import types
from collections import defaultdict
from datetime import datetime from datetime import datetime
import six import six
...@@ -75,7 +76,10 @@ def building_rtfd(): ...@@ -75,7 +76,10 @@ def building_rtfd():
or os.environ.get('DOC_BUILDING') or os.environ.get('DOC_BUILDING')
def log_deprecated(name="", text="", eos=""): _DEPRECATED_LOG_NUM = defaultdict(int)
def log_deprecated(name="", text="", eos="", max_num_warnings=None):
""" """
Log deprecation warning. Log deprecation warning.
...@@ -83,6 +87,7 @@ def log_deprecated(name="", text="", eos=""): ...@@ -83,6 +87,7 @@ def log_deprecated(name="", text="", eos=""):
name (str): name of the deprecated item. name (str): name of the deprecated item.
text (str, optional): information about the deprecation. text (str, optional): information about the deprecation.
eos (str, optional): end of service date such as "YYYY-MM-DD". eos (str, optional): end of service date such as "YYYY-MM-DD".
max_num_warnings (int, optional): the maximum number of times to print this warning
""" """
assert name or text assert name or text
if eos: if eos:
...@@ -96,13 +101,18 @@ def log_deprecated(name="", text="", eos=""): ...@@ -96,13 +101,18 @@ def log_deprecated(name="", text="", eos=""):
warn_msg = text warn_msg = text
if eos: if eos:
warn_msg += " Legacy period ends %s" % eos warn_msg += " Legacy period ends %s" % eos
if max_num_warnings is not None:
if _DEPRECATED_LOG_NUM[warn_msg] >= max_num_warnings:
return
_DEPRECATED_LOG_NUM[warn_msg] += 1
logger.warn("[Deprecated] " + warn_msg) logger.warn("[Deprecated] " + warn_msg)
def deprecated(text="", eos=""): def deprecated(text="", eos="", max_num_warnings=None):
""" """
Args: Args:
text, eos: same as :func:`log_deprecated`. text, eos, max_num_warnings: same as :func:`log_deprecated`.
Returns: Returns:
a decorator which deprecates the function. a decorator which deprecates the function.
...@@ -130,7 +140,7 @@ def deprecated(text="", eos=""): ...@@ -130,7 +140,7 @@ def deprecated(text="", eos=""):
@functools.wraps(func) @functools.wraps(func)
def new_func(*args, **kwargs): def new_func(*args, **kwargs):
name = "{} [{}]".format(func.__name__, get_location()) name = "{} [{}]".format(func.__name__, get_location())
log_deprecated(name, text, eos) log_deprecated(name, text, eos, max_num_warnings=max_num_warnings)
return func(*args, **kwargs) return func(*args, **kwargs)
return new_func return new_func
return deprecated_inner return deprecated_inner
......
...@@ -7,7 +7,7 @@ cd $DIR ...@@ -7,7 +7,7 @@ cd $DIR
export TF_CPP_MIN_LOG_LEVEL=2 export TF_CPP_MIN_LOG_LEVEL=2
export TF_CPP_MIN_VLOG_LEVEL=2 export TF_CPP_MIN_VLOG_LEVEL=2
# test import (#471) # test import (#471)
python -c 'from tensorpack.dataflow.imgaug import transform' python -c 'from tensorpack.dataflow import imgaug'
# Check that these private names can be imported because tensorpack is using them # Check that these private names can be imported because tensorpack is using them
python -c "from tensorflow.python.client.session import _FetchHandler" python -c "from tensorflow.python.client.session import _FetchHandler"
python -c "from tensorflow.python.training.monitored_session import _HookedSession" python -c "from tensorflow.python.training.monitored_session import _HookedSession"
...@@ -16,6 +16,7 @@ python -c "import tensorflow as tf; tf.Operation._add_control_input" ...@@ -16,6 +16,7 @@ python -c "import tensorflow as tf; tf.Operation._add_control_input"
# run tests # run tests
python -m tensorpack.callbacks.param_test python -m tensorpack.callbacks.param_test
python -m tensorpack.tfutils.unit_tests python -m tensorpack.tfutils.unit_tests
python -m unittest tensorpack.dataflow.imgaug._test
TENSORPACK_SERIALIZE=pyarrow python test_serializer.py TENSORPACK_SERIALIZE=pyarrow python test_serializer.py
TENSORPACK_SERIALIZE=msgpack python test_serializer.py TENSORPACK_SERIALIZE=msgpack python test_serializer.py
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment