Commit d8a647f2 authored by Yuxin Wu's avatar Yuxin Wu

reset state in augmentors

parent 2a60316c
......@@ -39,7 +39,7 @@ class ImageFromFile(DataFlow):
class AugmentImageComponent(ProxyDataFlow):
"""
Augment the image in each data point
Augment image in each data point
Args:
ds: a DataFlow dataset instance
augmentors: a list of ImageAugmentor instance
......@@ -52,7 +52,7 @@ class AugmentImageComponent(ProxyDataFlow):
def reset_state(self):
self.ds.reset_state()
# TODO aug reset
self.augs.reset_state()
def get_data(self):
for dp in self.ds.get_data():
......
......@@ -20,15 +20,18 @@ class ImageAugmentor(object):
__metaclass__ = ABCMeta
def __init__(self):
self.rng = get_rng(self)
self.reset_state()
def _init(self, params=None):
self.rng = get_rng(self)
self.reset_state()
if params:
for k, v in params.iteritems():
if k != 'self':
setattr(self, k, v)
def reset_state(self):
self.rng = get_rng(self)
def augment(self, img):
"""
Note: will both modify `img` in-place and return `img`
......@@ -64,3 +67,7 @@ class AugmentorList(ImageAugmentor):
img.arr = img.arr.astype('float32')
for aug in self.augs:
aug.augment(img)
def reset_state(self):
for a in self.augs:
a.reset_state()
......@@ -61,6 +61,7 @@ class GaussianDeform(ImageAugmentor):
shape: 2D image shape
randrange: default to shape[0] / 8
"""
super(GaussianDeform, self).__init__()
self.anchors = anchors
self.K = len(self.anchors)
self.shape = shape
......@@ -75,7 +76,6 @@ class GaussianDeform(ImageAugmentor):
self.randrange = self.shape[0] / 8
else:
self.randrange = randrange
self._init()
def _augment(self, img):
if img.coords:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment