Commit c9fde630 authored by Yuxin Wu's avatar Yuxin Wu Committed by GitHub

Make augmentors return a `Transform` instance. (#1290)

* pure-image can run

* external; tests

* keep backward compatibility

* Remove the need for "augment_return_xxx" by using `LazyTransform`.

* better printing and docs

* Let PhotometricAugmentor's interface similar to the old one

* update docs

* udpate examples

* update tests & warnings
parent 5af84e93
......@@ -376,6 +376,7 @@ _DEPRECATED_NAMES = set([
'DistributedTrainerParameterServer',
'InputDesc',
'inputs_desc',
'Augmentor',
# renamed items that should not appear in docs
'DumpTensor',
......
......@@ -23,22 +23,7 @@ as the DataFlow.
In other words, for simple mapping you do not need to write an augmentor.
An augmentor may do something more than just applying a mapping.
To do complicated augmentation, the interface you will need to implement is:
```python
class MyAug(imgaug.ImageAugmentor):
def _get_augment_params(self, img):
# Generated random params with self.rng
return params
def _augment(self, img, params):
return augmented_img
# optional method
def _augment_coords(self, coords, param):
# coords is a Nx2 floating point array, each row is (x, y)
return augmented_coords
```
To do custom augmentation, you can implement one yourself.
#### The Design of imgaug Module
......@@ -46,29 +31,59 @@ class MyAug(imgaug.ImageAugmentor):
The [imgaug module](../../modules/dataflow.imgaug.html) is designed to allow the following usage:
* Factor out randomness and determinism.
An augmentor may be randomized, but you can call
[augment_return_params](../../modules/dataflow.imgaug.html#tensorpack.dataflow.imgaug.Augmentor.augment_return_params)
to obtain the randomized parameters and then call
[augment_with_params](../../modules/dataflow.imgaug.html#tensorpack.dataflow.imgaug.Augmentor.augment_with_params)
on other data with the same randomized parameters.
* Because of the above reason, tensorpack's augmentor can augment multiple images together
easily. This is commonly used for augmenting an image together with its masks.
* An image augmentor (e.g. flip) may also augment a coordinate, with
[augment_coords](../../modules/dataflow.imgaug.html#tensorpack.dataflow.imgaug.ImageAugmentor.augment_coords).
In this way, images can be augmented together with
boxes, polygons, keypoints, etc.
Coordinate augmentation enforces floating points coordinates
An augmentor often contains randomized policy, e.g., it randomly perturbs each image differently.
However, its "deterministic" part needs to be factored out, so that
the same transformation can be re-applied to other data
assocaited with the image. This is achieved like this:
```python
tfm = augmentor.get_transform(img) # a deterministic transformation
new_img = tfm.apply_image(img)
new_img2 = tfm.apply_image(img2)
new_coords = tfm.apply_coords(coords)
```
Due to this design, it can augment images together with its annotations
(e.g., segmentation masks, bounding boxes, keypoints).
Our coordinate augmentation enforces floating points coordinates
to avoid quantization error.
When you don't need to re-apply the same transformation, you can also just call
```python
new_img = augmentor.augment(img)
```
* Reset random seed. Random seed can be reset by
[reset_state](../../modules/dataflow.imgaug.html#tensorpack.dataflow.imgaug.Augmentor.reset_state).
[reset_state](../../modules/dataflow.imgaug.html#tensorpack.dataflow.imgaug.ImageAugmentor.reset_state).
This is important for multi-process data loading, to make sure different
processes get different seeds.
The reset method is called automatically if you use tensorpack's
[image augmentation dataflow](../../modules/dataflow.html#tensorpack.dataflow.AugmentImageComponent).
[image augmentation dataflow](../../modules/dataflow.html#tensorpack.dataflow.AugmentImageComponent)
or if you use Python 3.7+.
Otherwise, **you are responsible** for calling it by yourself in subprocesses.
See the
[API documentation](../../modules/dataflow.imgaug.html#tensorpack.dataflow.imgaug.Augmentor.reset_state)
[API documentation](../../modules/dataflow.imgaug.html#tensorpack.dataflow.imgaug.ImageAugmentor.reset_state)
of this method for more details.
### Write an Augmentor
The interface you will need to implement is:
```python
class MyAug(imgaug.ImageAugmentor):
def get_transform(self, img):
# Randomly generate a deterministic transformation, to be applied on img
x = random_parameters()
return MyTransform(x)
class MyTransform(imgaug.Transform):
def apply_image(self, img):
return new_img
def apply_coords(self, coords):
return new_coords
```
Check out the zoo of builtin augmentors to have a better sense.
......@@ -5,7 +5,7 @@ import numpy as np
import cv2
from tensorpack.dataflow import RNGDataFlow
from tensorpack.dataflow.imgaug import transform
from tensorpack.dataflow.imgaug import ImageAugmentor, ResizeTransform
class DataFromListOfDict(RNGDataFlow):
......@@ -26,7 +26,7 @@ class DataFromListOfDict(RNGDataFlow):
yield dp
class CustomResize(transform.TransformAugmentorBase):
class CustomResize(ImageAugmentor):
"""
Try resizing the shortest edge to a certain number
while avoiding the longest edge to exceed max_size.
......@@ -44,7 +44,7 @@ class CustomResize(transform.TransformAugmentorBase):
short_edge_length = (short_edge_length, short_edge_length)
self._init(locals())
def _get_augment_params(self, img):
def get_transform(self, img):
h, w = img.shape[:2]
size = self.rng.randint(
self.short_edge_length[0], self.short_edge_length[1] + 1)
......@@ -59,7 +59,7 @@ class CustomResize(transform.TransformAugmentorBase):
neww = neww * scale
neww = int(neww + 0.5)
newh = int(newh + 0.5)
return transform.ResizeTransform(h, w, newh, neww, self.interp)
return ResizeTransform(h, w, newh, neww, self.interp)
def box_to_point8(boxes):
......
......@@ -90,9 +90,10 @@ class TrainingDataPreprocessor:
boxes[:, 1::2] *= height
# augmentation:
im, params = self.aug.augment_return_params(im)
tfms = self.aug.get_transform(im)
im = tfms.apply_image(im)
points = box_to_point8(boxes)
points = self.aug.augment_coords(points, params)
points = tfms.apply_coords(points)
boxes = point8_to_box(points)
if len(boxes):
assert np.min(np_area(boxes)) > 0, "Some boxes have zero area!"
......@@ -131,7 +132,7 @@ class TrainingDataPreprocessor:
for polys in segmentation:
if not self.cfg.DATA.ABSOLUTE_COORD:
polys = [p * width_height for p in polys]
polys = [self.aug.augment_coords(p, params) for p in polys]
polys = [tfms.apply_coords(p) for p in polys]
masks.append(segmentation_to_mask(polys, im.shape[0], gt_mask_width))
if len(masks):
......
......@@ -191,7 +191,7 @@ def get_data(name):
ds = dataset.BSDS500(name, shuffle=True)
class CropMultiple16(imgaug.ImageAugmentor):
def _get_augment_params(self, img):
def get_transform(self, img):
newh = img.shape[0] // 16 * 16
neww = img.shape[1] // 16 * 16
assert newh > 0 and neww > 0
......@@ -199,11 +199,7 @@ def get_data(name):
h0 = 0 if diffh == 0 else self.rng.randint(diffh)
diffw = img.shape[1] - neww
w0 = 0 if diffw == 0 else self.rng.randint(diffw)
return (h0, w0, newh, neww)
def _augment(self, img, param):
h0, w0, newh, neww = param
return img[h0:h0 + newh, w0:w0 + neww]
return imgaug.CropTransform(h0, w0, newh, neww)
if isTrain:
shape_aug = [
......
......@@ -161,10 +161,9 @@ class AugmentImageCoordinates(MapData):
validate_coords(coords)
if self._copy:
img, coords = copy_mod.deepcopy((img, coords))
img, prms = self.augs.augment_return_params(img)
dp[self._img_index] = img
coords = self.augs.augment_coords(coords, prms)
dp[self._coords_index] = coords
tfms = self.augs.get_transform(img)
dp[self._img_index] = tfms.apply_image(img)
dp[self._coords_index] = tfms.apply_coords(coords)
return dp
......@@ -207,15 +206,15 @@ class AugmentImageComponents(MapData):
major_image = index[0] # image to be used to get params. TODO better design?
im = copy_func(dp[major_image])
check_dtype(im)
im, prms = self.augs.augment_return_params(im)
dp[major_image] = im
tfms = self.augs.get_transform(im)
dp[major_image] = tfms.apply_image(im)
for idx in index[1:]:
check_dtype(dp[idx])
dp[idx] = self.augs.augment_with_params(copy_func(dp[idx]), prms)
dp[idx] = tfms.apply_image(copy_func(dp[idx]))
for idx in coords_index:
coords = copy_func(dp[idx])
validate_coords(coords)
dp[idx] = self.augs.augment_coords(coords, prms)
dp[idx] = tfms.apply_coords(coords)
return dp
super(AugmentImageComponents, self).__init__(ds, func)
......
......@@ -3,29 +3,170 @@
import sys
import numpy as np
import cv2
import unittest
from . import AugmentorList
from .crop import *
from .deform import *
from .imgproc import *
from .base import ImageAugmentor, AugmentorList
from .imgproc import Contrast
from .noise import SaltPepperNoise
from .noname import *
from .misc import Flip, Resize
anchors = [(0.2, 0.2), (0.7, 0.2), (0.8, 0.8), (0.5, 0.5), (0.2, 0.5)]
augmentors = AugmentorList([
def _rand_image(shape=(20, 20)):
return np.random.rand(*shape).astype("float32")
class LegacyBrightness(ImageAugmentor):
def __init__(self, delta, clip=True):
super(LegacyBrightness, self).__init__()
assert delta > 0
self._init(locals())
def _get_augment_params(self, _):
v = self._rand_range(-self.delta, self.delta)
return v
def _augment(self, img, v):
old_dtype = img.dtype
img = img.astype('float32')
img += v
if self.clip or old_dtype == np.uint8:
img = np.clip(img, 0, 255)
return img.astype(old_dtype)
class LegacyFlip(ImageAugmentor):
def __init__(self, horiz=False, vert=False, prob=0.5):
super(LegacyFlip, self).__init__()
if horiz and vert:
raise ValueError("Cannot do both horiz and vert. Please use two Flip instead.")
elif horiz:
self.code = 1
elif vert:
self.code = 0
else:
raise ValueError("At least one of horiz or vert has to be True!")
self._init(locals())
def _get_augment_params(self, img):
h, w = img.shape[:2]
do = self._rand_range() < self.prob
return (do, h, w)
def _augment(self, img, param):
do, _, _ = param
if do:
ret = cv2.flip(img, self.code)
if img.ndim == 3 and ret.ndim == 2:
ret = ret[:, :, np.newaxis]
else:
ret = img
return ret
def _augment_coords(self, coords, param):
do, h, w = param
if do:
if self.code == 0:
coords[:, 1] = h - coords[:, 1]
elif self.code == 1:
coords[:, 0] = w - coords[:, 0]
return coords
class ImgAugTest(unittest.TestCase):
def _get_augs(self):
return AugmentorList([
Contrast((0.8, 1.2)),
Flip(horiz=True),
Resize((30, 30)),
SaltPepperNoise()
])
def _get_augs_with_legacy(self):
return AugmentorList([
LegacyBrightness(0.5),
LegacyFlip(horiz=True),
Resize((30, 30)),
SaltPepperNoise()
])
def test_augmentors(self):
augmentors = self._get_augs()
img = _rand_image()
orig = img.copy()
tfms = augmentors.get_transform(img)
# test printing
print(augmentors)
print(tfms)
newimg = tfms.apply_image(img)
print(tfms) # lazy ones will instantiate after the first apply
newimg2 = tfms.apply_image(orig)
self.assertTrue(np.allclose(newimg, newimg2))
self.assertEqual(newimg2.shape[0], 30)
coords = np.asarray([[0, 0], [10, 12]], dtype="float32")
tfms.apply_coords(coords)
def test_legacy_usage(self):
augmentors = self._get_augs()
img = _rand_image()
orig = img.copy()
newimg, tfms = augmentors.augment_return_params(img)
newimg2 = augmentors.augment_with_params(orig, tfms)
self.assertTrue(np.allclose(newimg, newimg2))
self.assertEqual(newimg2.shape[0], 30)
coords = np.asarray([[0, 0], [10, 12]], dtype="float32")
augmentors.augment_coords(coords, tfms)
def test_legacy_augs_new_usage(self):
augmentors = self._get_augs_with_legacy()
img = _rand_image()
orig = img.copy()
tfms = augmentors.get_transform(img)
newimg = tfms.apply_image(img)
newimg2 = tfms.apply_image(orig)
self.assertTrue(np.allclose(newimg, newimg2))
self.assertEqual(newimg2.shape[0], 30)
coords = np.asarray([[0, 0], [10, 12]], dtype="float32")
tfms.apply_coords(coords)
def test_legacy_augs_legacy_usage(self):
augmentors = self._get_augs_with_legacy()
img = _rand_image()
orig = img.copy()
newimg, tfms = augmentors.augment_return_params(img)
newimg2 = augmentors.augment_with_params(orig, tfms)
self.assertTrue(np.allclose(newimg, newimg2))
self.assertEqual(newimg2.shape[0], 30)
coords = np.asarray([[0, 0], [10, 12]], dtype="float32")
augmentors.augment_coords(coords, tfms)
if __name__ == '__main__':
anchors = [(0.2, 0.2), (0.7, 0.2), (0.8, 0.8), (0.5, 0.5), (0.2, 0.5)]
augmentors = AugmentorList([
Contrast((0.8, 1.2)),
Flip(horiz=True),
GaussianDeform(anchors, (360, 480), 0.2, randrange=20),
# RandomCropRandomShape(0.3),
SaltPepperNoise()
])
])
img = cv2.imread(sys.argv[1])
newimg, prms = augmentors._augment_return_params(img)
cv2.imshow(" ", newimg.astype('uint8'))
cv2.waitKey()
img = cv2.imread(sys.argv[1])
newimg, prms = augmentors._augment_return_params(img)
cv2.imshow(" ", newimg.astype('uint8'))
cv2.waitKey()
newimg = augmentors._augment(img, prms)
cv2.imshow(" ", newimg.astype('uint8'))
cv2.waitKey()
newimg = augmentors._augment(img, prms)
cv2.imshow(" ", newimg.astype('uint8'))
cv2.waitKey()
This diff is collapsed.
......@@ -4,13 +4,12 @@
import numpy as np
import cv2
from .base import ImageAugmentor
from .meta import MapImage
from .base import PhotometricAugmentor
__all__ = ['ColorSpace', 'Grayscale', 'ToUint8', 'ToFloat32']
class ColorSpace(ImageAugmentor):
class ColorSpace(PhotometricAugmentor):
""" Convert into another color space. """
def __init__(self, mode, keepdims=True):
......@@ -20,6 +19,7 @@ class ColorSpace(ImageAugmentor):
keepdims (bool): keep the dimension of image unchanged if OpenCV
changes it.
"""
super(ColorSpace, self).__init__()
self._init(locals())
def _augment(self, img, _):
......@@ -43,13 +43,13 @@ class Grayscale(ColorSpace):
super(Grayscale, self).__init__(mode, keepdims)
class ToUint8(MapImage):
class ToUint8(PhotometricAugmentor):
""" Convert image to uint8. Useful to reduce communication overhead. """
def __init__(self):
super(ToUint8, self).__init__(lambda x: np.clip(x, 0, 255).astype(np.uint8), lambda x: x)
def _augment(self, img, _):
return np.clip(img, 0, 255).astype(np.uint8)
class ToFloat32(MapImage):
class ToFloat32(PhotometricAugmentor):
""" Convert image to float32, may increase quality of the augmentor. """
def __init__(self):
super(ToFloat32, self).__init__(lambda x: x.astype(np.float32), lambda x: x)
def _augment(self, img, _):
return img.astype(np.float32)
......@@ -6,14 +6,14 @@ import cv2
from ...utils.argtools import shape2d
from ...utils.develop import log_deprecated
from .base import ImageAugmentor
from .transform import CropTransform, TransformAugmentorBase
from .base import ImageAugmentor, ImagePlaceholder
from .transform import CropTransform, TransformList, ResizeTransform
from .misc import ResizeShortestEdge
__all__ = ['RandomCrop', 'CenterCrop', 'RandomCropRandomShape', 'GoogleNetRandomCropAndResize']
class RandomCrop(TransformAugmentorBase):
class RandomCrop(ImageAugmentor):
""" Randomly crop the image into a smaller one """
def __init__(self, crop_shape):
......@@ -25,7 +25,7 @@ class RandomCrop(TransformAugmentorBase):
super(RandomCrop, self).__init__()
self._init(locals())
def _get_augment_params(self, img):
def get_transform(self, img):
orig_shape = img.shape
assert orig_shape[0] >= self.crop_shape[0] \
and orig_shape[1] >= self.crop_shape[1], orig_shape
......@@ -36,7 +36,7 @@ class RandomCrop(TransformAugmentorBase):
return CropTransform(h0, w0, self.crop_shape[0], self.crop_shape[1])
class CenterCrop(TransformAugmentorBase):
class CenterCrop(ImageAugmentor):
""" Crop the image at the center"""
def __init__(self, crop_shape):
......@@ -47,7 +47,7 @@ class CenterCrop(TransformAugmentorBase):
crop_shape = shape2d(crop_shape)
self._init(locals())
def _get_augment_params(self, img):
def get_transform(self, img):
orig_shape = img.shape
assert orig_shape[0] >= self.crop_shape[0] \
and orig_shape[1] >= self.crop_shape[1], orig_shape
......@@ -56,7 +56,7 @@ class CenterCrop(TransformAugmentorBase):
return CropTransform(h0, w0, self.crop_shape[0], self.crop_shape[1])
class RandomCropRandomShape(TransformAugmentorBase):
class RandomCropRandomShape(ImageAugmentor):
""" Random crop with a random shape"""
def __init__(self, wmin, hmin,
......@@ -70,18 +70,19 @@ class RandomCropRandomShape(TransformAugmentorBase):
wmin, hmin, wmax, hmax: range to sample shape.
max_aspect_ratio (float): this argument has no effect and is deprecated.
"""
super(RandomCropRandomShape, self).__init__()
if max_aspect_ratio is not None:
log_deprecated("RandomCropRandomShape(max_aspect_ratio)", "It is never implemented!", "2020-06-06")
self._init(locals())
def _get_augment_params(self, img):
def get_transform(self, img):
hmax = self.hmax or img.shape[0]
wmax = self.wmax or img.shape[1]
h = self.rng.randint(self.hmin, hmax + 1)
w = self.rng.randint(self.wmin, wmax + 1)
diffh = img.shape[0] - h
diffw = img.shape[1] - w
assert diffh >= 0 and diffw >= 0
assert diffh >= 0 and diffw >= 0, str(diffh) + ", " + str(diffw)
y0 = 0 if diffh == 0 else self.rng.randint(diffh)
x0 = 0 if diffw == 0 else self.rng.randint(diffw)
return CropTransform(y0, x0, h, w)
......@@ -106,9 +107,10 @@ class GoogleNetRandomCropAndResize(ImageAugmentor):
aspect_ratio_range (tuple(float)): Defaults to make aspect ratio in 3/4-4/3.
target_shape (int): Defaults to 224, the standard ImageNet image shape.
"""
super(GoogleNetRandomCropAndResize, self).__init__()
self._init(locals())
def _augment(self, img, _):
def get_transform(self, img):
h, w = img.shape[:2]
area = h * w
for _ in range(10):
......@@ -121,12 +123,11 @@ class GoogleNetRandomCropAndResize(ImageAugmentor):
if hh <= h and ww <= w:
x1 = 0 if w == ww else self.rng.randint(0, w - ww)
y1 = 0 if h == hh else self.rng.randint(0, h - hh)
out = img[y1:y1 + hh, x1:x1 + ww]
out = cv2.resize(out, (self.target_shape, self.target_shape), interpolation=self.interp)
return out
out = ResizeShortestEdge(self.target_shape, interp=self.interp).augment(img)
out = CenterCrop(self.target_shape).augment(out)
return out
def _augment_coords(self, coords, param):
raise NotImplementedError()
return TransformList([
CropTransform(y1, x1, hh, ww),
ResizeTransform(hh, ww, self.target_shape, self.target_shape, interp=self.interp)
])
tfm1 = ResizeShortestEdge(self.target_shape, interp=self.interp).get_transform(img)
out_shape = (tfm1.new_h, tfm1.new_w)
tfm2 = CenterCrop(self.target_shape).get_transform(ImagePlaceholder(shape=out_shape))
return TransformList([tfm1, tfm2])
......@@ -6,6 +6,7 @@ import numpy as np
from ...utils import logger
from .base import ImageAugmentor
from .transform import TransformFactory
__all__ = []
......@@ -97,14 +98,11 @@ class GaussianDeform(ImageAugmentor):
self.randrange = randrange
self.sigma = sigma
def _get_augment_params(self, img):
def get_transform(self, img):
v = self.rng.rand(self.K, 2).astype('float32') - 0.5
v = v * 2 * self.randrange
return v
return TransformFactory(name=str(self), apply_image=lambda img: self._augment(img, v))
def _augment(self, img, v):
grid = self.grid + np.dot(self.gws, v)
return np_sample(img, grid)
def _augment_coords(self, coords, param):
raise NotImplementedError()
......@@ -3,10 +3,26 @@
import numpy as np
from .base import ImageAugmentor
from .transform import Transform
__all__ = ['IAAugmentor', 'Albumentations']
class IAATransform(Transform):
def __init__(self, aug, img_shape):
self._init(locals())
def apply_image(self, img):
return self.aug.augment_image(img)
def apply_coords(self, coords):
import imgaug as IA
points = [IA.Keypoint(x=x, y=y) for x, y in coords]
points = IA.KeypointsOnImage(points, shape=self.img_shape)
augmented = self.aug.augment_keypoints([points])[0].keypoints
return np.asarray([[p.x, p.y] for p in augmented])
class IAAugmentor(ImageAugmentor):
"""
Wrap an augmentor form the IAA library: https://github.com/aleju/imgaug.
......@@ -43,20 +59,16 @@ class IAAugmentor(ImageAugmentor):
super(IAAugmentor, self).__init__()
self._aug = augmentor
def _get_augment_params(self, img):
return (self._aug.to_deterministic(), img.shape)
def get_transform(self, img):
return IAATransform(self._aug.to_deterministic(), img.shape)
def _augment(self, img, param):
aug, _ = param
return aug.augment_image(img)
def _augment_coords(self, coords, param):
import imgaug as IA
aug, shape = param
points = [IA.Keypoint(x=x, y=y) for x, y in coords]
points = IA.KeypointsOnImage(points, shape=shape)
augmented = aug.augment_keypoints([points])[0].keypoints
return np.asarray([[p.x, p.y] for p in augmented])
class AlbumentationsTransform(Transform):
def __init__(self, aug, param):
self._init(locals())
def apply_image(self, img):
return self.aug.apply(img, **self.param)
class Albumentations(ImageAugmentor):
......@@ -81,11 +93,5 @@ class Albumentations(ImageAugmentor):
super(Albumentations, self).__init__()
self._aug = augmentor
def _get_augment_params(self, img):
return self._aug.get_params()
def _augment(self, img, param):
return self._aug.apply(img, **param)
def _augment_coords(self, coords, param):
raise NotImplementedError()
def get_transform(self, img):
return AlbumentationsTransform(self._aug, self._aug.get_params())
......@@ -7,12 +7,12 @@ import numpy as np
import cv2
from .base import ImageAugmentor
from .transform import TransformAugmentorBase, WarpAffineTransform
from .transform import WarpAffineTransform, CropTransform, TransformList
__all__ = ['Shift', 'Rotation', 'RotationAndCropValid', 'Affine']
class Shift(TransformAugmentorBase):
class Shift(ImageAugmentor):
""" Random horizontal and vertical shifts """
def __init__(self, horiz_frac=0, vert_frac=0,
......@@ -28,7 +28,7 @@ class Shift(TransformAugmentorBase):
super(Shift, self).__init__()
self._init(locals())
def _get_augment_params(self, img):
def get_transform(self, img):
max_dx = self.horiz_frac * img.shape[1]
max_dy = self.vert_frac * img.shape[0]
dx = np.round(self._rand_range(-max_dx, max_dx))
......@@ -40,7 +40,7 @@ class Shift(TransformAugmentorBase):
borderMode=self.border, borderValue=self.border_value)
class Rotation(TransformAugmentorBase):
class Rotation(ImageAugmentor):
""" Random rotate the image w.r.t a random center"""
def __init__(self, max_deg, center_range=(0, 1),
......@@ -61,7 +61,7 @@ class Rotation(TransformAugmentorBase):
super(Rotation, self).__init__()
self._init(locals())
def _get_augment_params(self, img):
def get_transform(self, img):
center = img.shape[1::-1] * self._rand_range(
self.center_range[0], self.center_range[1], (2,))
deg = self._rand_range(-self.max_deg, self.max_deg)
......@@ -100,29 +100,27 @@ class RotationAndCropValid(ImageAugmentor):
super(RotationAndCropValid, self).__init__()
self._init(locals())
def _get_augment_params(self, img):
def _get_deg(self, img):
deg = self._rand_range(-self.max_deg, self.max_deg)
if self.step_deg:
deg = deg // self.step_deg * self.step_deg
return deg
def _augment(self, img, deg):
def get_transform(self, img):
deg = self._get_deg(img)
h, w = img.shape[:2]
center = (img.shape[1] * 0.5, img.shape[0] * 0.5)
rot_m = cv2.getRotationMatrix2D((center[0] - 0.5, center[1] - 0.5), deg, 1)
ret = cv2.warpAffine(img, rot_m, img.shape[1::-1],
flags=self.interp, borderMode=cv2.BORDER_CONSTANT)
if img.ndim == 3 and ret.ndim == 2:
ret = ret[:, :, np.newaxis]
neww, newh = RotationAndCropValid.largest_rotated_rect(ret.shape[1], ret.shape[0], deg)
neww = min(neww, ret.shape[1])
newh = min(newh, ret.shape[0])
tfm = WarpAffineTransform(rot_m, (w, h), interp=self.interp)
neww, newh = RotationAndCropValid.largest_rotated_rect(w, h, deg)
neww = min(neww, w)
newh = min(newh, h)
newx = int(center[0] - neww * 0.5)
newy = int(center[1] - newh * 0.5)
# print(ret.shape, deg, newx, newy, neww, newh)
return ret[newy:newy + newh, newx:newx + neww]
def _augment_coords(self, coords, param):
raise NotImplementedError()
tfm2 = CropTransform(newy, newx, newh, neww)
return TransformList([tfm, tfm2])
@staticmethod
def largest_rotated_rect(w, h, angle):
......@@ -152,7 +150,7 @@ class RotationAndCropValid(ImageAugmentor):
return int(np.round(wr)), int(np.round(hr))
class Affine(TransformAugmentorBase):
class Affine(ImageAugmentor):
"""
Random affine transform of the image w.r.t to the image center.
Transformations involve:
......@@ -193,7 +191,7 @@ class Affine(TransformAugmentorBase):
super(Affine, self).__init__()
self._init(locals())
def _get_augment_params(self, img):
def get_transform(self, img):
if self.scale is not None:
scale = self._rand_range(self.scale[0], self.scale[1])
else:
......
......@@ -5,13 +5,13 @@
import numpy as np
import cv2
from .base import ImageAugmentor
from .base import PhotometricAugmentor
__all__ = ['Hue', 'Brightness', 'BrightnessScale', 'Contrast', 'MeanVarianceNormalize',
'GaussianBlur', 'Gamma', 'Clip', 'Saturation', 'Lighting', 'MinMaxNormalize']
class Hue(ImageAugmentor):
class Hue(PhotometricAugmentor):
""" Randomly change color hue.
"""
......@@ -43,7 +43,7 @@ class Hue(ImageAugmentor):
return img
class Brightness(ImageAugmentor):
class Brightness(PhotometricAugmentor):
"""
Adjust brightness by adding a random number.
"""
......@@ -51,15 +51,14 @@ class Brightness(ImageAugmentor):
"""
Args:
delta (float): Randomly add a value within [-delta,delta]
clip (bool): clip results to [0,255] if data type is uint8.
clip (bool): clip results to [0,255] even when data type is not uint8.
"""
super(Brightness, self).__init__()
assert delta > 0
self._init(locals())
def _get_augment_params(self, _):
v = self._rand_range(-self.delta, self.delta)
return v
return self._rand_range(-self.delta, self.delta)
def _augment(self, img, v):
old_dtype = img.dtype
......@@ -70,7 +69,7 @@ class Brightness(ImageAugmentor):
return img.astype(old_dtype)
class BrightnessScale(ImageAugmentor):
class BrightnessScale(PhotometricAugmentor):
"""
Adjust brightness by scaling by a random factor.
"""
......@@ -78,14 +77,13 @@ class BrightnessScale(ImageAugmentor):
"""
Args:
range (tuple): Randomly scale the image by a factor in (range[0], range[1])
clip (bool): clip results to [0,255] if data type is uint8.
clip (bool): clip results to [0,255] even when data type is not uint8.
"""
super(BrightnessScale, self).__init__()
self._init(locals())
def _get_augment_params(self, _):
v = self._rand_range(*self.range)
return v
return self._rand_range(*self.range)
def _augment(self, img, v):
old_dtype = img.dtype
......@@ -96,7 +94,7 @@ class BrightnessScale(ImageAugmentor):
return img.astype(old_dtype)
class Contrast(ImageAugmentor):
class Contrast(PhotometricAugmentor):
"""
Apply ``x = (x - mean) * contrast_factor + mean`` to each channel.
"""
......@@ -106,12 +104,12 @@ class Contrast(ImageAugmentor):
Args:
factor_range (list or tuple): an interval to randomly sample the `contrast_factor`.
rgb (bool or None): if None, use the mean per-channel.
clip (bool): clip to [0, 255] if data type is uint8.
clip (bool): clip to [0, 255] even when data type is not uint8.
"""
super(Contrast, self).__init__()
self._init(locals())
def _get_augment_params(self, img):
def _get_augment_params(self, _):
return self._rand_range(*self.factor_range)
def _augment(self, img, r):
......@@ -133,7 +131,7 @@ class Contrast(ImageAugmentor):
return img.astype(old_dtype)
class MeanVarianceNormalize(ImageAugmentor):
class MeanVarianceNormalize(PhotometricAugmentor):
"""
Linearly scales the image to have zero mean and unit norm.
``x = (x - mean) / adjusted_stddev``
......@@ -162,7 +160,7 @@ class MeanVarianceNormalize(ImageAugmentor):
return img
class GaussianBlur(ImageAugmentor):
class GaussianBlur(PhotometricAugmentor):
""" Gaussian blur the image with random window size"""
def __init__(self, max_size=3):
......@@ -173,7 +171,7 @@ class GaussianBlur(ImageAugmentor):
super(GaussianBlur, self).__init__()
self._init(locals())
def _get_augment_params(self, img):
def _get_augment_params(self, _):
sx, sy = self.rng.randint(self.max_size, size=(2,))
sx = sx * 2 + 1
sy = sy * 2 + 1
......@@ -184,7 +182,7 @@ class GaussianBlur(ImageAugmentor):
borderType=cv2.BORDER_REPLICATE), img.shape)
class Gamma(ImageAugmentor):
class Gamma(PhotometricAugmentor):
""" Randomly adjust gamma """
def __init__(self, range=(-0.5, 0.5)):
"""
......@@ -207,7 +205,7 @@ class Gamma(ImageAugmentor):
return ret
class Clip(ImageAugmentor):
class Clip(PhotometricAugmentor):
""" Clip the pixel values """
def __init__(self, min=0, max=255):
......@@ -218,11 +216,10 @@ class Clip(ImageAugmentor):
self._init(locals())
def _augment(self, img, _):
img = np.clip(img, self.min, self.max)
return img
return np.clip(img, self.min, self.max)
class Saturation(ImageAugmentor):
class Saturation(PhotometricAugmentor):
""" Randomly adjust saturation.
Follows the implementation in `fb.resnet.torch
<https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L218>`__.
......@@ -252,7 +249,7 @@ class Saturation(ImageAugmentor):
return ret.astype(old_dtype)
class Lighting(ImageAugmentor):
class Lighting(PhotometricAugmentor):
""" Lighting noise, as in the paper
`ImageNet Classification with Deep Convolutional Neural Networks
<https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf>`_.
......@@ -267,6 +264,7 @@ class Lighting(ImageAugmentor):
eigval: a vector of (3,). The eigenvalues of 3 channels.
eigvec: a 3x3 matrix. Each column is one eigen vector.
"""
super(Lighting, self).__init__()
eigval = np.asarray(eigval)
eigvec = np.asarray(eigvec)
assert eigval.shape == (3,)
......@@ -275,8 +273,7 @@ class Lighting(ImageAugmentor):
def _get_augment_params(self, img):
assert img.shape[2] == 3
ret = self.rng.randn(3) * self.std
return ret.astype('float32')
return (self.rng.randn(3) * self.std).astype("float32")
def _augment(self, img, v):
old_dtype = img.dtype
......@@ -289,7 +286,7 @@ class Lighting(ImageAugmentor):
return img.astype(old_dtype)
class MinMaxNormalize(ImageAugmentor):
class MinMaxNormalize(PhotometricAugmentor):
"""
Linearly scales the image to the range [min, max].
......
......@@ -3,6 +3,7 @@
from .base import ImageAugmentor
from .transform import NoOpTransform, TransformList, TransformFactory
__all__ = ['RandomChooseAug', 'MapImage', 'Identity', 'RandomApplyAug',
'RandomOrderAug']
......@@ -10,8 +11,8 @@ __all__ = ['RandomChooseAug', 'MapImage', 'Identity', 'RandomApplyAug',
class Identity(ImageAugmentor):
""" A no-op augmentor """
def _augment(self, img, _):
return img
def get_transform(self, img):
return NoOpTransform()
class RandomApplyAug(ImageAugmentor):
......@@ -22,44 +23,23 @@ class RandomApplyAug(ImageAugmentor):
def __init__(self, aug, prob):
"""
Args:
aug (ImageAugmentor): an augmentor
prob (float): the probability
aug (ImageAugmentor): an augmentor.
prob (float): the probability to apply the augmentor.
"""
self._init(locals())
super(RandomApplyAug, self).__init__()
def _get_augment_params(self, img):
def get_transform(self, img):
p = self.rng.rand()
if p < self.prob:
prm = self.aug._get_augment_params(img)
return (True, prm)
return self.aug.get_transform(img)
else:
return (False, None)
def _augment_return_params(self, img):
p = self.rng.rand()
if p < self.prob:
img, prms = self.aug._augment_return_params(img)
return img, (True, prms)
else:
return img, (False, None)
return NoOpTransform()
def reset_state(self):
super(RandomApplyAug, self).reset_state()
self.aug.reset_state()
def _augment(self, img, prm):
if not prm[0]:
return img
else:
return self.aug._augment(img, prm[1])
def _augment_coords(self, coords, prm):
if not prm[0]:
return coords
else:
return self.aug._augment_coords(coords, prm[1])
class RandomChooseAug(ImageAugmentor):
""" Randomly choose one from a list of augmentors """
......@@ -82,18 +62,9 @@ class RandomChooseAug(ImageAugmentor):
for a in self.aug_lists:
a.reset_state()
def _get_augment_params(self, img):
def get_transform(self, img):
aug_idx = self.rng.choice(len(self.aug_lists), p=self.prob)
aug_prm = self.aug_lists[aug_idx]._get_augment_params(img)
return aug_idx, aug_prm
def _augment(self, img, prm):
idx, prm = prm
return self.aug_lists[idx]._augment(img, prm)
def _augment_coords(self, coords, prm):
idx, prm = prm
return self.aug_lists[idx]._augment_coords(coords, prm)
return self.aug_lists[aug_idx].get_transform(img)
class RandomOrderAug(ImageAugmentor):
......@@ -115,30 +86,19 @@ class RandomOrderAug(ImageAugmentor):
for a in self.aug_lists:
a.reset_state()
def _get_augment_params(self, img):
# Note: If augmentors change the shape of image, get_augment_param might not work
# All augmentors should only rely on the shape of image
def get_transform(self, img):
# Note: this makes assumption that the augmentors do not make changes
# to the image that will affect how the transforms will be instantiated
# in the subsequent augmentors.
idxs = self.rng.permutation(len(self.aug_lists))
prms = [self.aug_lists[k]._get_augment_params(img)
tfms = [self.aug_lists[k].get_transform(img)
for k in range(len(self.aug_lists))]
return idxs, prms
def _augment(self, img, prm):
idxs, prms = prm
for k in idxs:
img = self.aug_lists[k]._augment(img, prms[k])
return img
def _augment_coords(self, coords, prm):
idxs, prms = prm
for k in idxs:
img = self.aug_lists[k]._augment_coords(coords, prms[k])
return img
return TransformList([tfms[k] for k in idxs])
class MapImage(ImageAugmentor):
"""
Map the image array by a function.
Map the image array by simple functions.
"""
def __init__(self, func, coord_func=None):
......@@ -146,16 +106,14 @@ class MapImage(ImageAugmentor):
Args:
func: a function which takes an image array and return an augmented one
coord_func: optional. A function which takes coordinates and return augmented ones.
Coordinates have the same format as :func:`ImageAugmentor.augment_coords`.
Coordinates should be Nx2 array of (x, y)s.
"""
super(MapImage, self).__init__()
self.func = func
self.coord_func = coord_func
def _augment(self, img, _):
return self.func(img)
def _augment_coords(self, coords, _):
if self.coord_func is None:
raise NotImplementedError
return self.coord_func(coords)
def get_transform(self, img):
if self.coord_func:
return TransformFactory(name="MapImage", apply_image=self.func, apply_coords=self.coord_func)
else:
return TransformFactory(name="MapImage", apply_image=self.func)
# -*- coding: utf-8 -*-
# File: misc.py
import numpy as np
import cv2
from ...utils import logger
from ...utils.argtools import shape2d
from .base import ImageAugmentor
from .transform import ResizeTransform, TransformAugmentorBase
from .transform import ResizeTransform, NoOpTransform, FlipTransform, TransposeTransform
__all__ = ['Flip', 'Resize', 'RandomResize', 'ResizeShortestEdge', 'Transpose']
......@@ -27,40 +25,20 @@ class Flip(ImageAugmentor):
super(Flip, self).__init__()
if horiz and vert:
raise ValueError("Cannot do both horiz and vert. Please use two Flip instead.")
elif horiz:
self.code = 1
elif vert:
self.code = 0
else:
if not horiz and not vert:
raise ValueError("At least one of horiz or vert has to be True!")
self._init(locals())
def _get_augment_params(self, img):
def get_transform(self, img):
h, w = img.shape[:2]
do = self._rand_range() < self.prob
return (do, h, w)
def _augment(self, img, param):
do, _, _ = param
if do:
ret = cv2.flip(img, self.code)
if img.ndim == 3 and ret.ndim == 2:
ret = ret[:, :, np.newaxis]
if not do:
return NoOpTransform()
else:
ret = img
return ret
def _augment_coords(self, coords, param):
do, h, w = param
if do:
if self.code == 0:
coords[:, 1] = h - coords[:, 1]
elif self.code == 1:
coords[:, 0] = w - coords[:, 0]
return coords
return FlipTransform(h, w, self.horiz)
class Resize(TransformAugmentorBase):
class Resize(ImageAugmentor):
""" Resize image to a target size"""
def __init__(self, shape, interp=cv2.INTER_LINEAR):
......@@ -72,13 +50,13 @@ class Resize(TransformAugmentorBase):
shape = tuple(shape2d(shape))
self._init(locals())
def _get_augment_params(self, img):
def get_transform(self, img):
return ResizeTransform(
img.shape[0], img.shape[1],
self.shape[0], self.shape[1], self.interp)
class ResizeShortestEdge(TransformAugmentorBase):
class ResizeShortestEdge(ImageAugmentor):
"""
Resize the shortest edge to a certain number while
keeping the aspect ratio.
......@@ -92,18 +70,17 @@ class ResizeShortestEdge(TransformAugmentorBase):
size = int(size)
self._init(locals())
def _get_augment_params(self, img):
def get_transform(self, img):
h, w = img.shape[:2]
scale = self.size * 1.0 / min(h, w)
if h < w:
newh, neww = self.size, int(scale * w + 0.5)
else:
newh, neww = int(scale * h + 0.5), self.size
return ResizeTransform(
h, w, newh, neww, self.interp)
return ResizeTransform(h, w, newh, neww, self.interp)
class RandomResize(TransformAugmentorBase):
class RandomResize(ImageAugmentor):
""" Randomly rescale width and height of the image."""
def __init__(self, xrange, yrange=None, minimum=(0, 0), aspect_ratio_thres=0.15,
......@@ -137,7 +114,7 @@ class RandomResize(TransformAugmentorBase):
if yrange is not None:
logger.warn("aspect_ratio_thres==0, yrange is not used!")
def _get_augment_params(self, img):
def get_transform(self, img):
cnt = 0
h, w = img.shape[:2]
......@@ -186,20 +163,9 @@ class Transpose(ImageAugmentor):
"""
super(Transpose, self).__init__()
self.prob = prob
self._init()
def _get_augment_params(self, img):
return self._rand_range() < self.prob
def _augment(self, img, do):
ret = img
if do:
ret = cv2.transpose(img)
if img.ndim == 3 and ret.ndim == 2:
ret = ret[:, :, np.newaxis]
return ret
def _augment_coords(self, coords, do):
if do:
coords = coords[:, ::-1]
return coords
def get_transform(self, _):
if self.rng.rand() < self.prob:
return TransposeTransform()
else:
return NoOpTransform()
......@@ -5,12 +5,12 @@
import numpy as np
import cv2
from .base import ImageAugmentor
from .base import PhotometricAugmentor
__all__ = ['JpegNoise', 'GaussianNoise', 'SaltPepperNoise']
class JpegNoise(ImageAugmentor):
class JpegNoise(PhotometricAugmentor):
""" Random JPEG noise. """
def __init__(self, quality_range=(40, 100)):
......@@ -29,7 +29,7 @@ class JpegNoise(ImageAugmentor):
return cv2.imdecode(enc, 1).astype(img.dtype)
class GaussianNoise(ImageAugmentor):
class GaussianNoise(PhotometricAugmentor):
"""
Add random Gaussian noise N(0, sigma^2) of the same shape to img.
"""
......@@ -53,7 +53,7 @@ class GaussianNoise(ImageAugmentor):
return ret.astype(old_dtype)
class SaltPepperNoise(ImageAugmentor):
class SaltPepperNoise(PhotometricAugmentor):
""" Salt and pepper noise.
Randomly set some elements in image to 0 or 255, regardless of its channels.
"""
......
......@@ -6,6 +6,7 @@ import numpy as np
from abc import abstractmethod
from .base import ImageAugmentor
from .transform import TransformFactory
__all__ = ['CenterPaste', 'BackgroundFiller', 'ConstantBackgroundFiller',
'RandomPaste']
......@@ -51,6 +52,10 @@ class ConstantBackgroundFiller(BackgroundFiller):
return np.zeros(return_shape, dtype=img.dtype) + self.value
# NOTE:
# apply_coords should be implemeted in paste transform, but not yet done
class CenterPaste(ImageAugmentor):
"""
Paste the image onto the center of a background canvas.
......@@ -67,7 +72,10 @@ class CenterPaste(ImageAugmentor):
self._init(locals())
def _augment(self, img, _):
def get_transform(self, _):
return TransformFactory(name=str(self), apply_image=lambda img: self._impl(img))
def _impl(self, img):
img_shape = img.shape[:2]
assert self.background_shape[0] >= img_shape[0] and self.background_shape[1] >= img_shape[1]
......@@ -78,24 +86,22 @@ class CenterPaste(ImageAugmentor):
background[y0:y0 + img_shape[0], x0:x0 + img_shape[1]] = img
return background
def _augment_coords(self, coords, param):
raise NotImplementedError()
class RandomPaste(CenterPaste):
"""
Randomly paste the image onto a background canvas.
"""
def _get_augment_params(self, img):
def get_transform(self, img):
img_shape = img.shape[:2]
assert self.background_shape[0] > img_shape[0] and self.background_shape[1] > img_shape[1]
y0 = self._rand_range(self.background_shape[0] - img_shape[0])
x0 = self._rand_range(self.background_shape[1] - img_shape[1])
return int(x0), int(y0)
l = int(x0), int(y0)
return TransformFactory(name=str(self), apply_image=lambda img: self._impl(img, l))
def _augment(self, img, loc):
def _impl(self, img, loc):
x0, y0 = loc
img_shape = img.shape[:2]
background = self.background_filler.fill(
......
This diff is collapsed.
......@@ -10,6 +10,7 @@ import functools
import importlib
import os
import types
from collections import defaultdict
from datetime import datetime
import six
......@@ -75,7 +76,10 @@ def building_rtfd():
or os.environ.get('DOC_BUILDING')
def log_deprecated(name="", text="", eos=""):
_DEPRECATED_LOG_NUM = defaultdict(int)
def log_deprecated(name="", text="", eos="", max_num_warnings=None):
"""
Log deprecation warning.
......@@ -83,6 +87,7 @@ def log_deprecated(name="", text="", eos=""):
name (str): name of the deprecated item.
text (str, optional): information about the deprecation.
eos (str, optional): end of service date such as "YYYY-MM-DD".
max_num_warnings (int, optional): the maximum number of times to print this warning
"""
assert name or text
if eos:
......@@ -96,13 +101,18 @@ def log_deprecated(name="", text="", eos=""):
warn_msg = text
if eos:
warn_msg += " Legacy period ends %s" % eos
if max_num_warnings is not None:
if _DEPRECATED_LOG_NUM[warn_msg] >= max_num_warnings:
return
_DEPRECATED_LOG_NUM[warn_msg] += 1
logger.warn("[Deprecated] " + warn_msg)
def deprecated(text="", eos=""):
def deprecated(text="", eos="", max_num_warnings=None):
"""
Args:
text, eos: same as :func:`log_deprecated`.
text, eos, max_num_warnings: same as :func:`log_deprecated`.
Returns:
a decorator which deprecates the function.
......@@ -130,7 +140,7 @@ def deprecated(text="", eos=""):
@functools.wraps(func)
def new_func(*args, **kwargs):
name = "{} [{}]".format(func.__name__, get_location())
log_deprecated(name, text, eos)
log_deprecated(name, text, eos, max_num_warnings=max_num_warnings)
return func(*args, **kwargs)
return new_func
return deprecated_inner
......
......@@ -7,7 +7,7 @@ cd $DIR
export TF_CPP_MIN_LOG_LEVEL=2
export TF_CPP_MIN_VLOG_LEVEL=2
# test import (#471)
python -c 'from tensorpack.dataflow.imgaug import transform'
python -c 'from tensorpack.dataflow import imgaug'
# Check that these private names can be imported because tensorpack is using them
python -c "from tensorflow.python.client.session import _FetchHandler"
python -c "from tensorflow.python.training.monitored_session import _HookedSession"
......@@ -16,6 +16,7 @@ python -c "import tensorflow as tf; tf.Operation._add_control_input"
# run tests
python -m tensorpack.callbacks.param_test
python -m tensorpack.tfutils.unit_tests
python -m unittest tensorpack.dataflow.imgaug._test
TENSORPACK_SERIALIZE=pyarrow python test_serializer.py
TENSORPACK_SERIALIZE=msgpack python test_serializer.py
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment