Commit d08c9c5a authored by Yuxin Wu's avatar Yuxin Wu

__repr__ for augmentors (fix #388)

parent 3ed43ab4
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
# File: base.py # File: base.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import inspect
import pprint
from abc import abstractmethod, ABCMeta from abc import abstractmethod, ABCMeta
from ...utils.utils import get_rng from ...utils.utils import get_rng
import six import six
...@@ -64,6 +66,31 @@ class Augmentor(object): ...@@ -64,6 +66,31 @@ class Augmentor(object):
size = [] size = []
return self.rng.uniform(low, high, size) return self.rng.uniform(low, high, size)
def __repr__(self):
"""
Produce something like:
"imgaug.MyAugmentor(field1={self.field1}, field2={self.field2})"
"""
argspec = inspect.getargspec(self.__init__)
assert argspec.varargs is None, "The default __repr__ doesn't work for vaargs!"
assert argspec.keywords is None, "The default __repr__ doesn't work for kwargs!"
fields = argspec.args[1:]
index_field_has_default = len(fields) - (0 if argspec.defaults is None else len(argspec.defaults))
classname = type(self).__name__
argstr = []
for idx, f in enumerate(fields):
assert hasattr(self, f), \
"Attribute {} not found! The default __repr__ only works if attributes match the constructor.".format(f)
attr = getattr(self, f)
if idx >= index_field_has_default:
if attr is argspec.defaults[idx - index_field_has_default]:
continue
argstr.append("{}={}".format(f, pprint.pformat(attr)))
return "imgaug.{}({})".format(classname, ', '.join(argstr))
__str__ = __repr__
class ImageAugmentor(Augmentor): class ImageAugmentor(Augmentor):
def _augment_coords(self, coords, param): def _augment_coords(self, coords, param):
......
...@@ -12,6 +12,7 @@ __all__ = ['ColorSpace', 'Grayscale', 'ToUint8', 'ToFloat32'] ...@@ -12,6 +12,7 @@ __all__ = ['ColorSpace', 'Grayscale', 'ToUint8', 'ToFloat32']
class ColorSpace(ImageAugmentor): class ColorSpace(ImageAugmentor):
""" Convert into another colorspace. """ """ Convert into another colorspace. """
def __init__(self, mode, keepdims=True): def __init__(self, mode, keepdims=True):
""" """
Args: Args:
......
...@@ -90,6 +90,7 @@ class GaussianDeform(ImageAugmentor): ...@@ -90,6 +90,7 @@ class GaussianDeform(ImageAugmentor):
self.randrange = self.shape[0] / 8 self.randrange = self.shape[0] / 8
else: else:
self.randrange = randrange self.randrange = randrange
self.sigma = sigma
def _get_augment_params(self, img): def _get_augment_params(self, img):
v = self.rng.rand(self.K, 2).astype('float32') - 0.5 v = self.rng.rand(self.K, 2).astype('float32') - 0.5
......
...@@ -148,6 +148,7 @@ class MapImage(ImageAugmentor): ...@@ -148,6 +148,7 @@ class MapImage(ImageAugmentor):
Args: Args:
func: a function which takes an image array and return an augmented one func: a function which takes an image array and return an augmented one
""" """
super(MapImage, self).__init__()
self.func = func self.func = func
self.coord_func = coord_func self.coord_func = coord_func
......
...@@ -33,8 +33,7 @@ class Flip(ImageAugmentor): ...@@ -33,8 +33,7 @@ class Flip(ImageAugmentor):
self.code = 0 self.code = 0
else: else:
raise ValueError("At least one of horiz or vert has to be True!") raise ValueError("At least one of horiz or vert has to be True!")
self.prob = prob self._init(locals())
self._init()
def _get_augment_params(self, img): def _get_augment_params(self, img):
h, w = img.shape[:2] h, w = img.shape[:2]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment