Commit 13e2a3e3 authored by Yuxin Wu's avatar Yuxin Wu

augment with func

parent caa93135
...@@ -6,7 +6,7 @@ from abc import abstractmethod, ABCMeta ...@@ -6,7 +6,7 @@ from abc import abstractmethod, ABCMeta
from ...utils import get_rng from ...utils import get_rng
from six.moves import zip from six.moves import zip
__all__ = ['ImageAugmentor', 'AugmentorList'] __all__ = ['ImageAugmentor', 'AugmentorList', 'AugmentWithFunc']
class ImageAugmentor(object): class ImageAugmentor(object):
""" Base class for an image augmentor""" """ Base class for an image augmentor"""
...@@ -63,6 +63,14 @@ class ImageAugmentor(object): ...@@ -63,6 +63,14 @@ class ImageAugmentor(object):
size = [] size = []
return low + self.rng.rand(*size) * (high - low) return low + self.rng.rand(*size) * (high - low)
class AugmentWithFunc(ImageAugmentor):
""" func: takes an image and return an image"""
def __init__(self, func):
self.func = func
def _augment(self, img, _):
return self.func(img)
class AugmentorList(ImageAugmentor): class AugmentorList(ImageAugmentor):
""" """
Augment by a list of augmentors Augment by a list of augmentors
...@@ -98,3 +106,5 @@ class AugmentorList(ImageAugmentor): ...@@ -98,3 +106,5 @@ class AugmentorList(ImageAugmentor):
""" Will reset state of each augmentor """ """ Will reset state of each augmentor """
for a in self.augs: for a in self.augs:
a.reset_state() a.reset_state()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment