Commit 2a60316c authored by Yuxin Wu's avatar Yuxin Wu

augmentimage as a dataflow

parent 8759e324
...@@ -4,14 +4,11 @@ ...@@ -4,14 +4,11 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np import numpy as np
import copy
from .base import DataFlow, ProxyDataFlow from .base import DataFlow, ProxyDataFlow
from .imgaug import AugmentorList, Image
from ..utils import * from ..utils import *
__all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData', __all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData',
'MapDataComponent', 'RandomChooseData', 'MapDataComponent', 'RandomChooseData' ]
'AugmentImageComponent']
class BatchData(ProxyDataFlow): class BatchData(ProxyDataFlow):
def __init__(self, ds, batch_size, remainder=False): def __init__(self, ds, batch_size, remainder=False):
...@@ -184,18 +181,3 @@ class RandomChooseData(DataFlow): ...@@ -184,18 +181,3 @@ class RandomChooseData(DataFlow):
yield next(itr) yield next(itr)
except StopIteration: except StopIteration:
return return
def AugmentImageComponent(ds, augmentors, index=0):
"""
Augment the image in each data point
Args:
ds: a DataFlow dataset instance
augmentors: a list of ImageAugmentor instance
index: the index of image in each data point. default to be 0
"""
# TODO reset rng at the beginning of each get_data
aug = AugmentorList(augmentors)
return MapDataComponent(
ds,
lambda img: aug.augment(Image(img)).arr,
index)
...@@ -5,9 +5,11 @@ ...@@ -5,9 +5,11 @@
import numpy as np import numpy as np
import cv2 import cv2
from .base import DataFlow import copy
from .base import DataFlow, ProxyDataFlow
from .imgaug import AugmentorList, Image
__all__ = ['ImageFromFile'] __all__ = ['ImageFromFile', 'AugmentImageComponent']
class ImageFromFile(DataFlow): class ImageFromFile(DataFlow):
""" generate rgb images from files """ """ generate rgb images from files """
...@@ -34,3 +36,26 @@ class ImageFromFile(DataFlow): ...@@ -34,3 +36,26 @@ class ImageFromFile(DataFlow):
im = cv2.resize(im, self.resize[::-1]) im = cv2.resize(im, self.resize[::-1])
yield [im] yield [im]
class AugmentImageComponent(ProxyDataFlow):
"""
Augment the image in each data point
Args:
ds: a DataFlow dataset instance
augmentors: a list of ImageAugmentor instance
index: the index of image in each data point. default to be 0
"""
def __init__(self, ds, augmentors, index=0):
super(AugmentImageComponent, self).__init__(ds)
self.augs = AugmentorList(augmentors)
self.index = index
def reset_state(self):
self.ds.reset_state()
# TODO aug reset
def get_data(self):
for dp in self.ds.get_data():
dp = copy.deepcopy(dp)
dp[self.index] = self.augs.augment(Image(dp[self.index])).arr
yield dp
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment