Commit d8a647f2 authored by Yuxin Wu's avatar Yuxin Wu

reset state in augmentors

parent 2a60316c
...@@ -39,7 +39,7 @@ class ImageFromFile(DataFlow): ...@@ -39,7 +39,7 @@ class ImageFromFile(DataFlow):
class AugmentImageComponent(ProxyDataFlow): class AugmentImageComponent(ProxyDataFlow):
""" """
Augment the image in each data point Augment image in each data point
Args: Args:
ds: a DataFlow dataset instance ds: a DataFlow dataset instance
augmentors: a list of ImageAugmentor instance augmentors: a list of ImageAugmentor instance
...@@ -52,7 +52,7 @@ class AugmentImageComponent(ProxyDataFlow): ...@@ -52,7 +52,7 @@ class AugmentImageComponent(ProxyDataFlow):
def reset_state(self): def reset_state(self):
self.ds.reset_state() self.ds.reset_state()
# TODO aug reset self.augs.reset_state()
def get_data(self): def get_data(self):
for dp in self.ds.get_data(): for dp in self.ds.get_data():
......
...@@ -20,15 +20,18 @@ class ImageAugmentor(object): ...@@ -20,15 +20,18 @@ class ImageAugmentor(object):
__metaclass__ = ABCMeta __metaclass__ = ABCMeta
def __init__(self): def __init__(self):
self.rng = get_rng(self) self.reset_state()
def _init(self, params=None): def _init(self, params=None):
self.rng = get_rng(self) self.reset_state()
if params: if params:
for k, v in params.iteritems(): for k, v in params.iteritems():
if k != 'self': if k != 'self':
setattr(self, k, v) setattr(self, k, v)
def reset_state(self):
self.rng = get_rng(self)
def augment(self, img): def augment(self, img):
""" """
Note: will both modify `img` in-place and return `img` Note: will both modify `img` in-place and return `img`
...@@ -64,3 +67,7 @@ class AugmentorList(ImageAugmentor): ...@@ -64,3 +67,7 @@ class AugmentorList(ImageAugmentor):
img.arr = img.arr.astype('float32') img.arr = img.arr.astype('float32')
for aug in self.augs: for aug in self.augs:
aug.augment(img) aug.augment(img)
def reset_state(self):
for a in self.augs:
a.reset_state()
...@@ -61,6 +61,7 @@ class GaussianDeform(ImageAugmentor): ...@@ -61,6 +61,7 @@ class GaussianDeform(ImageAugmentor):
shape: 2D image shape shape: 2D image shape
randrange: default to shape[0] / 8 randrange: default to shape[0] / 8
""" """
super(GaussianDeform, self).__init__()
self.anchors = anchors self.anchors = anchors
self.K = len(self.anchors) self.K = len(self.anchors)
self.shape = shape self.shape = shape
...@@ -75,7 +76,6 @@ class GaussianDeform(ImageAugmentor): ...@@ -75,7 +76,6 @@ class GaussianDeform(ImageAugmentor):
self.randrange = self.shape[0] / 8 self.randrange = self.shape[0] / 8
else: else:
self.randrange = randrange self.randrange = randrange
self._init()
def _augment(self, img): def _augment(self, img):
if img.coords: if img.coords:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment