Commit d0423fb6 authored by Yuxin Wu

add a 'copy' option to augmentors (#203)

parent 10c6f81d
......@@ -221,7 +221,7 @@ def get_data(dataset_name):
imgaug.CenterCrop((224, 224)),
imgaug.MapImage(lambda x: x - pp_mean_224),
]
ds = AugmentImageComponent(ds, augmentors)
ds = AugmentImageComponent(ds, augmentors, copy=False)
ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain)
if isTrain:
ds = PrefetchDataZMQ(ds, min(12, multiprocessing.cpu_count()))
......
......@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os
import copy
import glob
import cv2
import numpy as np
......@@ -85,7 +86,7 @@ class BSDS500(RNGDataFlow):
if self.shuffle:
self.rng.shuffle(idxs)
for k in idxs:
yield [self.data[k], self.label[k]]
yield [copy.copy(self.data[k]), self.label[k]]
try:
......
......@@ -108,7 +108,8 @@ class CifarBase(RNGDataFlow):
if self.shuffle:
self.rng.shuffle(idxs)
for k in idxs:
yield copy.copy(self.data[k])
# since cifar is quite small, just do it for safety
yield copy.deepcopy(self.data[k])
def get_per_pixel_mean(self):
"""
......
......@@ -5,6 +5,7 @@
import os
import numpy as np
import copy
from ...utils import logger
from ...utils.fs import get_dataset_path
......@@ -56,7 +57,8 @@ class SVHNDigit(RNGDataFlow):
if self.shuffle:
self.rng.shuffle(idxs)
for k in idxs:
yield [self.X[k], self.Y[k]]
# since svhn is quite small, just do it for safety
yield [copy.copy(self.X[k]), self.Y[k]]
@staticmethod
def get_per_pixel_mean():
......
......@@ -4,6 +4,7 @@
import numpy as np
import cv2
import copy as copy_mod
from .base import RNGDataFlow
from .common import MapDataComponent, MapData
from .imgaug import AugmentorList
......@@ -49,12 +50,15 @@ class AugmentImageComponent(MapDataComponent):
"""
Apply image augmentors on 1 component.
"""
def __init__(self, ds, augmentors, index=0):
def __init__(self, ds, augmentors, index=0, copy=False):
"""
Args:
ds (DataFlow): input DataFlow.
augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` to be applied in order.
index (int): the index of the image component to be augmented.
copy (bool): make a copy of input images so the original data
won't be modified. Turn it on when it's dangerous to
modify the input (e.g. when inputs are persistent in memory).
"""
if isinstance(augmentors, AugmentorList):
self.augs = augmentors
......@@ -65,6 +69,8 @@ class AugmentImageComponent(MapDataComponent):
def func(x):
try:
if copy:
x = copy_mod.deepcopy(x)
ret = self.augs.augment(x)
except KeyboardInterrupt:
raise
......@@ -88,24 +94,28 @@ class AugmentImageComponents(MapData):
Apply image augmentors on several components, with shared augmentation parameters.
"""
def __init__(self, ds, augmentors, index=(0, 1)):
def __init__(self, ds, augmentors, index=(0, 1), copy=False):
"""
Args:
ds (DataFlow): input DataFlow.
augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` instance to be applied in order.
index: tuple of indices of components.
copy (bool): make a copy of input images so the original data
won't be modified. Turn it on when it's dangerous to
modify the input (e.g. when inputs are persistent in memory).
"""
self.augs = AugmentorList(augmentors)
self.ds = ds
self._nr_error = 0
def func(dp):
# Deep-copy a component only when ``copy`` is requested; otherwise hand the
# same object through unchanged so no extra allocation happens.
copy_func = copy_mod.deepcopy if copy else (lambda x: x)  # noqa
try:
im = dp[index[0]]
im = copy_func(dp[index[0]])
im, prms = self.augs._augment_return_params(im)
dp[index[0]] = im
for idx in index[1:]:
dp[idx] = self.augs._augment(dp[idx], prms)
dp[idx] = self.augs._augment(copy_func(dp[idx]), prms)
return dp
except KeyboardInterrupt:
raise
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment