Commit 8c0106e4 authored by Yuxin Wu's avatar Yuxin Wu

update docs in imgaug

parent 34357e77
......@@ -12,6 +12,12 @@ from ..utils.argtools import shape2d
__all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageCoordinates', 'AugmentImageComponents']
def _valid_coords(coords):
assert coords.ndim == 2, coords.ndim
assert coords.shape[1] == 2, coords.shape
assert np.issubdtype(coords.dtype, np.float), coords.dtype
class ImageFromFile(RNGDataFlow):
""" Produce images read from a list of files. """
def __init__(self, files, channel=3, resize=None, shuffle=False):
......@@ -49,14 +55,14 @@ class ImageFromFile(RNGDataFlow):
class AugmentImageComponent(MapDataComponent):
"""
Apply image augmentors on 1 component.
Apply image augmentors on 1 image component.
"""
def __init__(self, ds, augmentors, index=0, copy=True):
"""
Args:
ds (DataFlow): input DataFlow.
augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` to be applied in order.
index (int): the index of the image component to be augmented.
index (int): the index of the image component to be augmented in the datapoint.
copy (bool): Some augmentors modify the input images. When copy is
True, a copy will be made before any augmentors are applied,
to keep the original images not modified.
......@@ -117,9 +123,7 @@ class AugmentImageCoordinates(MapData):
def func(dp):
try:
img, coords = dp[img_index], dp[coords_index]
assert coords.ndim == 2, coords.ndim
assert coords.shape[1] == 2, coords.shape
assert np.issubdtype(coords.dtype, np.float), coords.dtype
_valid_coords(coords)
if copy:
img, coords = copy_mod.deepcopy((img, coords))
img, prms = self.augs._augment_return_params(img)
......@@ -145,14 +149,25 @@ class AugmentImageCoordinates(MapData):
class AugmentImageComponents(MapData):
"""
Apply image augmentors on several components, with shared augmentation parameters.
Example:
.. code-block:: python
ds = MyDataFlow() # produce [image(HWC), segmask(HW), keypoint(Nx2)]
ds = AugmentImageComponents(
ds, augs,
index=(0,1), coords_index=(2,))
"""
def __init__(self, ds, augmentors, index=(0, 1), copy=True):
def __init__(self, ds, augmentors, index=(0, 1), coords_index=(), copy=True):
"""
Args:
ds (DataFlow): input DataFlow.
augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` instance to be applied in order.
index: tuple of indices of components.
index: tuple of indices of the image components.
coords_index: tuple of indices of the coordinates components.
copy (bool): Some augmentors modify the input images. When copy is
True, a copy will be made before any augmentors are applied,
to keep the original images not modified.
......@@ -169,11 +184,16 @@ class AugmentImageComponents(MapData):
dp = copy_mod.copy(dp) # always do a shallow copy, make sure the list is intact
copy_func = copy_mod.deepcopy if copy else lambda x: x # noqa
try:
im = copy_func(dp[index[0]])
major_image = index[0] # image to be used to get params. TODO better design?
im = copy_func(dp[major_image])
im, prms = self.augs._augment_return_params(im)
dp[index[0]] = im
dp[major_image] = im
for idx in index[1:]:
dp[idx] = self.augs._augment(copy_func(dp[idx]), prms)
for idx in coords_index:
coords = copy_func(dp[idx])
_valid_coords(coords)
dp[idx] = self.augs._augment_coords(coords, prms)
return dp
except KeyboardInterrupt:
raise
......
......@@ -44,16 +44,20 @@ class Augmentor(object):
@abstractmethod
def _augment(self, d, param):
"""
augment with the given param and return the new image
Augment with the given param and return the new data.
The augmentor is allowed to modify data in-place.
"""
def _get_augment_params(self, d):
"""
get the augmentor parameters
Get the augmentor parameters.
"""
return None
def _rand_range(self, low=1.0, high=None, size=None):
"""
Uniform float random number between low and high.
"""
if high is None:
low, high = 0, low
if size is None:
......@@ -64,9 +68,15 @@ class Augmentor(object):
class ImageAugmentor(Augmentor):
def _augment_coords(self, coords, param):
"""
Augment the coordinates given the param.
By default, keeps coordinates unchanged.
If a subclass changes coordinates but couldn't implement this method,
it should ``raise NotImplementedError()``.
Args:
coords: Nx2 floating point nparray where each row is (x, y)
Returns:
new coords
"""
return coords
......@@ -86,7 +96,7 @@ class AugmentorList(ImageAugmentor):
def _get_augment_params(self, img):
# the next augmentor requires the previous one to finish
raise RuntimeError("Cannot simply get parameters of a AugmentorList!")
raise RuntimeError("Cannot simply get all parameters of a AugmentorList without running the augmentation!")
def _augment_return_params(self, img):
assert img.ndim in [2, 3], img.ndim
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment