Commit 8c0106e4 authored by Yuxin Wu's avatar Yuxin Wu

update docs in imgaug

parent 34357e77
...@@ -12,6 +12,12 @@ from ..utils.argtools import shape2d ...@@ -12,6 +12,12 @@ from ..utils.argtools import shape2d
__all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageCoordinates', 'AugmentImageComponents'] __all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageCoordinates', 'AugmentImageComponents']
def _valid_coords(coords):
assert coords.ndim == 2, coords.ndim
assert coords.shape[1] == 2, coords.shape
assert np.issubdtype(coords.dtype, np.float), coords.dtype
class ImageFromFile(RNGDataFlow): class ImageFromFile(RNGDataFlow):
""" Produce images read from a list of files. """ """ Produce images read from a list of files. """
def __init__(self, files, channel=3, resize=None, shuffle=False): def __init__(self, files, channel=3, resize=None, shuffle=False):
...@@ -49,14 +55,14 @@ class ImageFromFile(RNGDataFlow): ...@@ -49,14 +55,14 @@ class ImageFromFile(RNGDataFlow):
class AugmentImageComponent(MapDataComponent): class AugmentImageComponent(MapDataComponent):
""" """
Apply image augmentors on 1 component. Apply image augmentors on 1 image component.
""" """
def __init__(self, ds, augmentors, index=0, copy=True): def __init__(self, ds, augmentors, index=0, copy=True):
""" """
Args: Args:
ds (DataFlow): input DataFlow. ds (DataFlow): input DataFlow.
augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` to be applied in order. augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` to be applied in order.
index (int): the index of the image component to be augmented. index (int): the index of the image component to be augmented in the datapoint.
copy (bool): Some augmentors modify the input images. When copy is copy (bool): Some augmentors modify the input images. When copy is
True, a copy will be made before any augmentors are applied, True, a copy will be made before any augmentors are applied,
to keep the original images not modified. to keep the original images not modified.
...@@ -117,9 +123,7 @@ class AugmentImageCoordinates(MapData): ...@@ -117,9 +123,7 @@ class AugmentImageCoordinates(MapData):
def func(dp): def func(dp):
try: try:
img, coords = dp[img_index], dp[coords_index] img, coords = dp[img_index], dp[coords_index]
assert coords.ndim == 2, coords.ndim _valid_coords(coords)
assert coords.shape[1] == 2, coords.shape
assert np.issubdtype(coords.dtype, np.float), coords.dtype
if copy: if copy:
img, coords = copy_mod.deepcopy((img, coords)) img, coords = copy_mod.deepcopy((img, coords))
img, prms = self.augs._augment_return_params(img) img, prms = self.augs._augment_return_params(img)
...@@ -145,14 +149,25 @@ class AugmentImageCoordinates(MapData): ...@@ -145,14 +149,25 @@ class AugmentImageCoordinates(MapData):
class AugmentImageComponents(MapData): class AugmentImageComponents(MapData):
""" """
Apply image augmentors on several components, with shared augmentation parameters. Apply image augmentors on several components, with shared augmentation parameters.
Example:
.. code-block:: python
ds = MyDataFlow() # produce [image(HWC), segmask(HW), keypoint(Nx2)]
ds = AugmentImageComponents(
ds, augs,
index=(0,1), coords_index=(2,))
""" """
def __init__(self, ds, augmentors, index=(0, 1), copy=True): def __init__(self, ds, augmentors, index=(0, 1), coords_index=(), copy=True):
""" """
Args: Args:
ds (DataFlow): input DataFlow. ds (DataFlow): input DataFlow.
augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` instance to be applied in order. augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` instance to be applied in order.
index: tuple of indices of components. index: tuple of indices of the image components.
coords_index: tuple of indices of the coordinates components.
copy (bool): Some augmentors modify the input images. When copy is copy (bool): Some augmentors modify the input images. When copy is
True, a copy will be made before any augmentors are applied, True, a copy will be made before any augmentors are applied,
to keep the original images not modified. to keep the original images not modified.
...@@ -169,11 +184,16 @@ class AugmentImageComponents(MapData): ...@@ -169,11 +184,16 @@ class AugmentImageComponents(MapData):
dp = copy_mod.copy(dp) # always do a shallow copy, make sure the list is intact dp = copy_mod.copy(dp) # always do a shallow copy, make sure the list is intact
copy_func = copy_mod.deepcopy if copy else lambda x: x # noqa copy_func = copy_mod.deepcopy if copy else lambda x: x # noqa
try: try:
im = copy_func(dp[index[0]]) major_image = index[0] # image to be used to get params. TODO better design?
im = copy_func(dp[major_image])
im, prms = self.augs._augment_return_params(im) im, prms = self.augs._augment_return_params(im)
dp[index[0]] = im dp[major_image] = im
for idx in index[1:]: for idx in index[1:]:
dp[idx] = self.augs._augment(copy_func(dp[idx]), prms) dp[idx] = self.augs._augment(copy_func(dp[idx]), prms)
for idx in coords_index:
coords = copy_func(dp[idx])
_valid_coords(coords)
dp[idx] = self.augs._augment_coords(coords, prms)
return dp return dp
except KeyboardInterrupt: except KeyboardInterrupt:
raise raise
......
...@@ -44,16 +44,20 @@ class Augmentor(object): ...@@ -44,16 +44,20 @@ class Augmentor(object):
@abstractmethod @abstractmethod
def _augment(self, d, param): def _augment(self, d, param):
""" """
augment with the given param and return the new image Augment with the given param and return the new data.
The augmentor is allowed to modify data in-place.
""" """
def _get_augment_params(self, d): def _get_augment_params(self, d):
""" """
get the augmentor parameters Get the augmentor parameters.
""" """
return None return None
def _rand_range(self, low=1.0, high=None, size=None): def _rand_range(self, low=1.0, high=None, size=None):
"""
Uniform float random number between low and high.
"""
if high is None: if high is None:
low, high = 0, low low, high = 0, low
if size is None: if size is None:
...@@ -64,9 +68,15 @@ class Augmentor(object): ...@@ -64,9 +68,15 @@ class Augmentor(object):
class ImageAugmentor(Augmentor): class ImageAugmentor(Augmentor):
def _augment_coords(self, coords, param): def _augment_coords(self, coords, param):
""" """
Augment the coordinates given the param.
By default, keeps coordinates unchanged. By default, keeps coordinates unchanged.
If a subclass changes coordinates but couldn't implement this method, If a subclass changes coordinates but couldn't implement this method,
it should ``raise NotImplementedError()``. it should ``raise NotImplementedError()``.
Args:
coords: Nx2 floating point nparray where each row is (x, y)
Returns:
new coords
""" """
return coords return coords
...@@ -86,7 +96,7 @@ class AugmentorList(ImageAugmentor): ...@@ -86,7 +96,7 @@ class AugmentorList(ImageAugmentor):
def _get_augment_params(self, img): def _get_augment_params(self, img):
# the next augmentor requires the previous one to finish # the next augmentor requires the previous one to finish
raise RuntimeError("Cannot simply get parameters of a AugmentorList!") raise RuntimeError("Cannot simply get all parameters of a AugmentorList without running the augmentation!")
def _augment_return_params(self, img): def _augment_return_params(self, img):
assert img.ndim in [2, 3], img.ndim assert img.ndim in [2, 3], img.ndim
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment