Commit 831d0dda authored by Yuxin Wu's avatar Yuxin Wu

Wrap albu/albumentations (#399)

parent a9864bf0
......@@ -26,12 +26,12 @@ ON_RTD = (os.environ.get('READTHEDOCS') == 'True')
MOCK_MODULES = ['tabulate', 'h5py',
'cv2', 'zmq', 'lmdb',
'msgpack', 'msgpack_numpy', 'pyarrow',
'sklearn', 'sklearn.datasets',
'scipy', 'scipy.misc', 'scipy.io',
'tornado', 'tornado.concurrent',
'horovod', 'horovod.tensorflow',
'subprocess32', 'functools32',
'imgaug']
'subprocess32', 'functools32']
# it's better to have tensorflow installed (for some docs to show)
# but it's OK to mock it as well
......
......@@ -3,7 +3,9 @@ tensorpack.dataflow.imgaug package
This package contains Tensorpack's augmentors.
Note that other image augmentation libraries can be wrapped into Tensorpack's interface as well.
For example, see `imgaug.IAAugmentor <#tensorpack.dataflow.imgaug.IAAugmentor>`_.
For example, `imgaug.IAAugmentor <#tensorpack.dataflow.imgaug.IAAugmentor>`_
and `imgaug.Albumentations <#tensorpack.dataflow.imgaug.Albumentations`_
wrap two popular image augmentation libraries.
.. container:: custom-index
......
......@@ -5,7 +5,7 @@ import numpy as np
from .base import ImageAugmentor
__all__ = ['IAAugmentor']
__all__ = ['IAAugmentor', 'Albumentations']
class IAAugmentor(ImageAugmentor):
......@@ -38,6 +38,7 @@ class IAAugmentor(ImageAugmentor):
return aug.augment_image(img)
def _augment_coords(self, coords, param):
import imgaug as IA
aug, shape = param
points = [IA.Keypoint(x=x, y=y) for x, y in coords]
points = IA.KeypointsOnImage(points, shape=shape)
......@@ -45,8 +46,24 @@ class IAAugmentor(ImageAugmentor):
return np.asarray([[p.x, p.y] for p in augmented])
from ...utils.develop import create_dummy_class # noqa
try:
import imgaug as IA
except ImportError:
IAAugmentor = create_dummy_class('IAAugmentor', 'imgaug') # noqa
class Albumentations(ImageAugmentor):
"""
Wrap an augmentor form the albumentations library: https://github.com/albu/albumentations
Coordinate augmentation is not supported by the library.
"""
def __init__(self, augmentor):
"""
Args:
augmentor (albumentations.BasicTransform):
"""
super(Albumentations, self).__init__()
self._aug = augmentor
def _get_augment_params(self, img):
return self._aug.get_params()
def _augment(self, img, param):
return self._aug.apply(img, **param)
def _augment_coords(self, coords, param):
raise NotImplementedError()
......@@ -29,6 +29,7 @@ def create_dummy_class(klass, dependency):
Returns:
class: a class object
"""
assert not building_rtfd()
class _DummyMetaClass(type):
# throw error on class attribute access
......@@ -55,6 +56,8 @@ def create_dummy_func(func, dependency):
Returns:
function: a function object
"""
assert not building_rtfd()
if isinstance(dependency, (list, tuple)):
dependency = ','.join(dependency)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment