Commit 2a60316c authored by Yuxin Wu's avatar Yuxin Wu

augmentimage as a dataflow

parent 8759e324
......@@ -4,14 +4,11 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np
import copy
from .base import DataFlow, ProxyDataFlow
from .imgaug import AugmentorList, Image
from ..utils import *
__all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData',
'MapDataComponent', 'RandomChooseData',
'AugmentImageComponent']
'MapDataComponent', 'RandomChooseData' ]
class BatchData(ProxyDataFlow):
def __init__(self, ds, batch_size, remainder=False):
......@@ -184,18 +181,3 @@ class RandomChooseData(DataFlow):
yield next(itr)
except StopIteration:
return
def AugmentImageComponent(ds, augmentors, index=0):
"""
Augment the image in each data point
Args:
ds: a DataFlow dataset instance
augmentors: a list of ImageAugmentor instance
index: the index of image in each data point. default to be 0
"""
# TODO reset rng at the beginning of each get_data
aug = AugmentorList(augmentors)
return MapDataComponent(
ds,
lambda img: aug.augment(Image(img)).arr,
index)
......@@ -5,9 +5,11 @@
import numpy as np
import cv2
from .base import DataFlow
import copy
from .base import DataFlow, ProxyDataFlow
from .imgaug import AugmentorList, Image
__all__ = ['ImageFromFile']
__all__ = ['ImageFromFile', 'AugmentImageComponent']
class ImageFromFile(DataFlow):
""" generate rgb images from files """
......@@ -34,3 +36,26 @@ class ImageFromFile(DataFlow):
im = cv2.resize(im, self.resize[::-1])
yield [im]
class AugmentImageComponent(ProxyDataFlow):
"""
Augment the image in each data point
Args:
ds: a DataFlow dataset instance
augmentors: a list of ImageAugmentor instance
index: the index of image in each data point. default to be 0
"""
def __init__(self, ds, augmentors, index=0):
super(AugmentImageComponent, self).__init__(ds)
self.augs = AugmentorList(augmentors)
self.index = index
def reset_state(self):
self.ds.reset_state()
# TODO aug reset
def get_data(self):
for dp in self.ds.get_data():
dp = copy.deepcopy(dp)
dp[self.index] = self.augs.augment(Image(dp[self.index])).arr
yield dp
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment