Commit a6745419 authored by Yuxin Wu's avatar Yuxin Wu

make augmentors copy=True by default. Also add shallow copy to MapDataComponents (#203, #207)

parent 20338134
......@@ -23,8 +23,9 @@ function library. Tensopack trainers integrate these two components and add more
details such as multi-GPU training. At the same time it keeps the power of customization to you
through callbacks.
* :doc:`callback` are like ``tf.train.SessionRunHook`` plugins, or extensions. During training,
* Callbacks are like ``tf.train.SessionRunHook``, or plugins, or extensions. During training,
everything you want to do other than the main iterations can be defined through callbacks.
See :doc:`callback` for some examples what you can do.
User Tutorials
========================
......
......@@ -132,7 +132,7 @@ def get_data(name):
# the original image shape (321x481) in BSDS is not a multiple of 16
IMAGE_SHAPE = (320, 480)
shape_aug = [imgaug.CenterCrop(IMAGE_SHAPE)]
ds = AugmentImageComponents(ds, shape_aug, (0, 1))
ds = AugmentImageComponents(ds, shape_aug, (0, 1), copy=False)
def f(m): # thresholding
m[m >= 0.50] = 1
......@@ -145,7 +145,7 @@ def get_data(name):
imgaug.Brightness(63, clip=False),
imgaug.Contrast((0.4, 1.5)),
]
ds = AugmentImageComponent(ds, augmentors)
ds = AugmentImageComponent(ds, augmentors, copy=False)
ds = BatchDataByShape(ds, 8, idx=0)
ds = PrefetchDataZMQ(ds, 1)
else:
......
......@@ -147,7 +147,7 @@ def get_data(train_or_test):
imgaug.MapImage(lambda x: x - pp_mean),
imgaug.CenterCrop((224, 224)),
]
ds = AugmentImageComponent(ds, augmentors)
ds = AugmentImageComponent(ds, augmentors, copy=False)
ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain)
if isTrain:
ds = PrefetchDataZMQ(ds, 6)
......
......@@ -252,7 +252,7 @@ def get_data(train_or_test):
imgaug.CenterCrop((299, 299)),
imgaug.MapImage(lambda x: x - pp_mean_299),
]
ds = AugmentImageComponent(ds, augmentors)
ds = AugmentImageComponent(ds, augmentors, copy=False)
ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain)
if isTrain:
ds = PrefetchDataZMQ(ds, min(12, multiprocessing.cpu_count()))
......
......@@ -180,7 +180,7 @@ def get_data(train_or_test):
imgaug.CenterCrop((224, 224)),
imgaug.ToUint8()
]
ds = AugmentImageComponent(ds, augmentors)
ds = AugmentImageComponent(ds, augmentors, copy=False)
if isTrain:
ds = PrefetchDataZMQ(ds, min(20, multiprocessing.cpu_count()))
ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain)
......
......@@ -4,6 +4,7 @@
from __future__ import division
import numpy as np
from copy import copy
from termcolor import colored
from collections import deque, defaultdict
from six.moves import range, map
......@@ -209,6 +210,9 @@ class MapData(ProxyDataFlow):
func (datapoint -> datapoint | None): takes a datapoint and returns a new
datapoint. Return None to discard this data point.
Note that if you use the filter feature, ``ds.size()`` will be incorrect.
Note:
Be careful if func modifies datapoints.
"""
super(MapData, self).__init__(ds)
self.func = func
......@@ -230,11 +234,16 @@ class MapDataComponent(MapData):
return None to discard this datapoint.
Note that if you use the filter feature, ``ds.size()`` will be incorrect.
index (int): index of the component.
Note:
This proxy itself doesn't modify the datapoints. But be careful because func
may modify the components.
"""
def f(dp):
r = func(dp[index])
if r is None:
return None
dp = copy(dp) # avoid modifying the list
dp[index] = r
return dp
super(MapDataComponent, self).__init__(ds, f)
......
......@@ -4,7 +4,6 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os
import copy
import glob
import cv2
import numpy as np
......@@ -86,7 +85,7 @@ class BSDS500(RNGDataFlow):
if self.shuffle:
self.rng.shuffle(idxs)
for k in idxs:
yield [copy.copy(self.data[k]), self.label[k]]
yield [self.data[k], self.label[k]]
try:
......
......@@ -9,7 +9,6 @@ import pickle
import numpy as np
import six
from six.moves import range
import copy
from ...utils import logger
from ...utils.fs import download, get_dataset_path
......@@ -109,7 +108,7 @@ class CifarBase(RNGDataFlow):
self.rng.shuffle(idxs)
for k in idxs:
# since cifar is quite small, just do it for safety
yield copy.deepcopy(self.data[k])
yield self.data[k]
def get_per_pixel_mean(self):
"""
......
......@@ -5,7 +5,6 @@
import os
import numpy as np
import copy
from ...utils import logger
from ...utils.fs import get_dataset_path
......@@ -58,7 +57,7 @@ class SVHNDigit(RNGDataFlow):
self.rng.shuffle(idxs)
for k in idxs:
# since svhn is quite small, just do it for safety
yield [copy.copy(self.X[k]), self.Y[k]]
yield [self.X[k], self.Y[k]]
@staticmethod
def get_per_pixel_mean():
......
......@@ -50,15 +50,16 @@ class AugmentImageComponent(MapDataComponent):
"""
Apply image augmentors on 1 component.
"""
def __init__(self, ds, augmentors, index=0, copy=False):
def __init__(self, ds, augmentors, index=0, copy=True):
"""
Args:
ds (DataFlow): input DataFlow.
augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` to be applied in order.
index (int): the index of the image component to be augmented.
copy (bool): make a copy of input images so the original data
won't be modified. Turn it on when it's dangerous to
modify the input (e.g. when inputs are persistent in memory).
copy (bool): Some augmentors modify the input images. When copy is
True, a copy will be made before any augmentors are applied,
to keep the original images not modified.
Turn it off to save time when you know it's OK.
"""
if isinstance(augmentors, AugmentorList):
self.augs = augmentors
......@@ -94,29 +95,30 @@ class AugmentImageComponents(MapData):
Apply image augmentors on several components, with shared augmentation parameters.
"""
def __init__(self, ds, augmentors, index=(0, 1), copy=False):
def __init__(self, ds, augmentors, index=(0, 1), copy=True):
"""
Args:
ds (DataFlow): input DataFlow.
augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` instance to be applied in order.
index: tuple of indices of components.
copy (bool): make a copy of input images so the original data
won't be modified. Turn it on when it's dangerous to
modify the input (e.g. when inputs are persistent in memory).
copy (bool): Some augmentors modify the input images. When copy is
True, a copy will be made before any augmentors are applied,
to keep the original images not modified.
Turn it off to save time when you know it's OK.
"""
self.augs = AugmentorList(augmentors)
self.ds = ds
self._nr_error = 0
def func(dp):
dp = copy_mod.copy(dp) # always do a shallow copy, make sure the list is intact
copy_func = copy_mod.deepcopy if copy else lambda x: x # noqa
dp = copy_func(dp)
try:
im = dp[index[0]]
im = copy_func(dp[index[0]])
im, prms = self.augs._augment_return_params(im)
dp[index[0]] = im
for idx in index[1:]:
dp[idx] = self.augs._augment(dp[idx], prms)
dp[idx] = self.augs._augment(copy_func(dp[idx]), prms)
return dp
except KeyboardInterrupt:
raise
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment