Commit 831d0dda authored by Yuxin Wu's avatar Yuxin Wu

Wrap albu/albumentations (#399)

parent a9864bf0
...@@ -26,12 +26,12 @@ ON_RTD = (os.environ.get('READTHEDOCS') == 'True') ...@@ -26,12 +26,12 @@ ON_RTD = (os.environ.get('READTHEDOCS') == 'True')
MOCK_MODULES = ['tabulate', 'h5py', MOCK_MODULES = ['tabulate', 'h5py',
'cv2', 'zmq', 'lmdb', 'cv2', 'zmq', 'lmdb',
'msgpack', 'msgpack_numpy', 'pyarrow',
'sklearn', 'sklearn.datasets', 'sklearn', 'sklearn.datasets',
'scipy', 'scipy.misc', 'scipy.io', 'scipy', 'scipy.misc', 'scipy.io',
'tornado', 'tornado.concurrent', 'tornado', 'tornado.concurrent',
'horovod', 'horovod.tensorflow', 'horovod', 'horovod.tensorflow',
'subprocess32', 'functools32', 'subprocess32', 'functools32']
'imgaug']
# it's better to have tensorflow installed (for some docs to show) # it's better to have tensorflow installed (for some docs to show)
# but it's OK to mock it as well # but it's OK to mock it as well
......
...@@ -3,7 +3,9 @@ tensorpack.dataflow.imgaug package ...@@ -3,7 +3,9 @@ tensorpack.dataflow.imgaug package
This package contains Tensorpack's augmentors. This package contains Tensorpack's augmentors.
Note that other image augmentation libraries can be wrapped into Tensorpack's interface as well. Note that other image augmentation libraries can be wrapped into Tensorpack's interface as well.
For example, see `imgaug.IAAugmentor <#tensorpack.dataflow.imgaug.IAAugmentor>`_. For example, `imgaug.IAAugmentor <#tensorpack.dataflow.imgaug.IAAugmentor>`_
and `imgaug.Albumentations <#tensorpack.dataflow.imgaug.Albumentations`_
wrap two popular image augmentation libraries.
.. container:: custom-index .. container:: custom-index
......
...@@ -5,7 +5,7 @@ import numpy as np ...@@ -5,7 +5,7 @@ import numpy as np
from .base import ImageAugmentor from .base import ImageAugmentor
__all__ = ['IAAugmentor'] __all__ = ['IAAugmentor', 'Albumentations']
class IAAugmentor(ImageAugmentor): class IAAugmentor(ImageAugmentor):
...@@ -38,6 +38,7 @@ class IAAugmentor(ImageAugmentor): ...@@ -38,6 +38,7 @@ class IAAugmentor(ImageAugmentor):
return aug.augment_image(img) return aug.augment_image(img)
def _augment_coords(self, coords, param): def _augment_coords(self, coords, param):
import imgaug as IA
aug, shape = param aug, shape = param
points = [IA.Keypoint(x=x, y=y) for x, y in coords] points = [IA.Keypoint(x=x, y=y) for x, y in coords]
points = IA.KeypointsOnImage(points, shape=shape) points = IA.KeypointsOnImage(points, shape=shape)
...@@ -45,8 +46,24 @@ class IAAugmentor(ImageAugmentor): ...@@ -45,8 +46,24 @@ class IAAugmentor(ImageAugmentor):
return np.asarray([[p.x, p.y] for p in augmented]) return np.asarray([[p.x, p.y] for p in augmented])
from ...utils.develop import create_dummy_class # noqa class Albumentations(ImageAugmentor):
try: """
import imgaug as IA Wrap an augmentor form the albumentations library: https://github.com/albu/albumentations
except ImportError: Coordinate augmentation is not supported by the library.
IAAugmentor = create_dummy_class('IAAugmentor', 'imgaug') # noqa """
def __init__(self, augmentor):
"""
Args:
augmentor (albumentations.BasicTransform):
"""
super(Albumentations, self).__init__()
self._aug = augmentor
def _get_augment_params(self, img):
return self._aug.get_params()
def _augment(self, img, param):
return self._aug.apply(img, **param)
def _augment_coords(self, coords, param):
raise NotImplementedError()
...@@ -29,6 +29,7 @@ def create_dummy_class(klass, dependency): ...@@ -29,6 +29,7 @@ def create_dummy_class(klass, dependency):
Returns: Returns:
class: a class object class: a class object
""" """
assert not building_rtfd()
class _DummyMetaClass(type): class _DummyMetaClass(type):
# throw error on class attribute access # throw error on class attribute access
...@@ -55,6 +56,8 @@ def create_dummy_func(func, dependency): ...@@ -55,6 +56,8 @@ def create_dummy_func(func, dependency):
Returns: Returns:
function: a function object function: a function object
""" """
assert not building_rtfd()
if isinstance(dependency, (list, tuple)): if isinstance(dependency, (list, tuple)):
dependency = ','.join(dependency) dependency = ','.join(dependency)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment