Commit a6745419 authored by Yuxin Wu's avatar Yuxin Wu

make augmentors copy=True by default. Also add shallow copy to MapDataComponents (#203, #207)

parent 20338134
...@@ -23,8 +23,9 @@ function library. Tensopack trainers integrate these two components and add more ...@@ -23,8 +23,9 @@ function library. Tensopack trainers integrate these two components and add more
details such as multi-GPU training. At the same time it keeps the power of customization to you details such as multi-GPU training. At the same time it keeps the power of customization to you
through callbacks. through callbacks.
* :doc:`callback` are like ``tf.train.SessionRunHook`` plugins, or extensions. During training, * Callbacks are like ``tf.train.SessionRunHook``, or plugins, or extensions. During training,
everything you want to do other than the main iterations can be defined through callbacks. everything you want to do other than the main iterations can be defined through callbacks.
See :doc:`callback` for some examples what you can do.
User Tutorials User Tutorials
======================== ========================
......
...@@ -132,7 +132,7 @@ def get_data(name): ...@@ -132,7 +132,7 @@ def get_data(name):
# the original image shape (321x481) in BSDS is not a multiple of 16 # the original image shape (321x481) in BSDS is not a multiple of 16
IMAGE_SHAPE = (320, 480) IMAGE_SHAPE = (320, 480)
shape_aug = [imgaug.CenterCrop(IMAGE_SHAPE)] shape_aug = [imgaug.CenterCrop(IMAGE_SHAPE)]
ds = AugmentImageComponents(ds, shape_aug, (0, 1)) ds = AugmentImageComponents(ds, shape_aug, (0, 1), copy=False)
def f(m): # thresholding def f(m): # thresholding
m[m >= 0.50] = 1 m[m >= 0.50] = 1
...@@ -145,7 +145,7 @@ def get_data(name): ...@@ -145,7 +145,7 @@ def get_data(name):
imgaug.Brightness(63, clip=False), imgaug.Brightness(63, clip=False),
imgaug.Contrast((0.4, 1.5)), imgaug.Contrast((0.4, 1.5)),
] ]
ds = AugmentImageComponent(ds, augmentors) ds = AugmentImageComponent(ds, augmentors, copy=False)
ds = BatchDataByShape(ds, 8, idx=0) ds = BatchDataByShape(ds, 8, idx=0)
ds = PrefetchDataZMQ(ds, 1) ds = PrefetchDataZMQ(ds, 1)
else: else:
......
...@@ -147,7 +147,7 @@ def get_data(train_or_test): ...@@ -147,7 +147,7 @@ def get_data(train_or_test):
imgaug.MapImage(lambda x: x - pp_mean), imgaug.MapImage(lambda x: x - pp_mean),
imgaug.CenterCrop((224, 224)), imgaug.CenterCrop((224, 224)),
] ]
ds = AugmentImageComponent(ds, augmentors) ds = AugmentImageComponent(ds, augmentors, copy=False)
ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain) ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain)
if isTrain: if isTrain:
ds = PrefetchDataZMQ(ds, 6) ds = PrefetchDataZMQ(ds, 6)
......
...@@ -252,7 +252,7 @@ def get_data(train_or_test): ...@@ -252,7 +252,7 @@ def get_data(train_or_test):
imgaug.CenterCrop((299, 299)), imgaug.CenterCrop((299, 299)),
imgaug.MapImage(lambda x: x - pp_mean_299), imgaug.MapImage(lambda x: x - pp_mean_299),
] ]
ds = AugmentImageComponent(ds, augmentors) ds = AugmentImageComponent(ds, augmentors, copy=False)
ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain) ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain)
if isTrain: if isTrain:
ds = PrefetchDataZMQ(ds, min(12, multiprocessing.cpu_count())) ds = PrefetchDataZMQ(ds, min(12, multiprocessing.cpu_count()))
......
...@@ -180,7 +180,7 @@ def get_data(train_or_test): ...@@ -180,7 +180,7 @@ def get_data(train_or_test):
imgaug.CenterCrop((224, 224)), imgaug.CenterCrop((224, 224)),
imgaug.ToUint8() imgaug.ToUint8()
] ]
ds = AugmentImageComponent(ds, augmentors) ds = AugmentImageComponent(ds, augmentors, copy=False)
if isTrain: if isTrain:
ds = PrefetchDataZMQ(ds, min(20, multiprocessing.cpu_count())) ds = PrefetchDataZMQ(ds, min(20, multiprocessing.cpu_count()))
ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain) ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain)
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
from __future__ import division from __future__ import division
import numpy as np import numpy as np
from copy import copy
from termcolor import colored from termcolor import colored
from collections import deque, defaultdict from collections import deque, defaultdict
from six.moves import range, map from six.moves import range, map
...@@ -209,6 +210,9 @@ class MapData(ProxyDataFlow): ...@@ -209,6 +210,9 @@ class MapData(ProxyDataFlow):
func (datapoint -> datapoint | None): takes a datapoint and returns a new func (datapoint -> datapoint | None): takes a datapoint and returns a new
datapoint. Return None to discard this data point. datapoint. Return None to discard this data point.
Note that if you use the filter feature, ``ds.size()`` will be incorrect. Note that if you use the filter feature, ``ds.size()`` will be incorrect.
Note:
Be careful if func modifies datapoints.
""" """
super(MapData, self).__init__(ds) super(MapData, self).__init__(ds)
self.func = func self.func = func
...@@ -230,11 +234,16 @@ class MapDataComponent(MapData): ...@@ -230,11 +234,16 @@ class MapDataComponent(MapData):
return None to discard this datapoint. return None to discard this datapoint.
Note that if you use the filter feature, ``ds.size()`` will be incorrect. Note that if you use the filter feature, ``ds.size()`` will be incorrect.
index (int): index of the component. index (int): index of the component.
Note:
This proxy itself doesn't modify the datapoints. But be careful because func
may modify the components.
""" """
def f(dp): def f(dp):
r = func(dp[index]) r = func(dp[index])
if r is None: if r is None:
return None return None
dp = copy(dp) # avoid modifying the list
dp[index] = r dp[index] = r
return dp return dp
super(MapDataComponent, self).__init__(ds, f) super(MapDataComponent, self).__init__(ds, f)
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os import os
import copy
import glob import glob
import cv2 import cv2
import numpy as np import numpy as np
...@@ -86,7 +85,7 @@ class BSDS500(RNGDataFlow): ...@@ -86,7 +85,7 @@ class BSDS500(RNGDataFlow):
if self.shuffle: if self.shuffle:
self.rng.shuffle(idxs) self.rng.shuffle(idxs)
for k in idxs: for k in idxs:
yield [copy.copy(self.data[k]), self.label[k]] yield [self.data[k], self.label[k]]
try: try:
......
...@@ -9,7 +9,6 @@ import pickle ...@@ -9,7 +9,6 @@ import pickle
import numpy as np import numpy as np
import six import six
from six.moves import range from six.moves import range
import copy
from ...utils import logger from ...utils import logger
from ...utils.fs import download, get_dataset_path from ...utils.fs import download, get_dataset_path
...@@ -109,7 +108,7 @@ class CifarBase(RNGDataFlow): ...@@ -109,7 +108,7 @@ class CifarBase(RNGDataFlow):
self.rng.shuffle(idxs) self.rng.shuffle(idxs)
for k in idxs: for k in idxs:
# since cifar is quite small, just do it for safety # since cifar is quite small, just do it for safety
yield copy.deepcopy(self.data[k]) yield self.data[k]
def get_per_pixel_mean(self): def get_per_pixel_mean(self):
""" """
......
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
import os import os
import numpy as np import numpy as np
import copy
from ...utils import logger from ...utils import logger
from ...utils.fs import get_dataset_path from ...utils.fs import get_dataset_path
...@@ -58,7 +57,7 @@ class SVHNDigit(RNGDataFlow): ...@@ -58,7 +57,7 @@ class SVHNDigit(RNGDataFlow):
self.rng.shuffle(idxs) self.rng.shuffle(idxs)
for k in idxs: for k in idxs:
# since svhn is quite small, just do it for safety # since svhn is quite small, just do it for safety
yield [copy.copy(self.X[k]), self.Y[k]] yield [self.X[k], self.Y[k]]
@staticmethod @staticmethod
def get_per_pixel_mean(): def get_per_pixel_mean():
......
...@@ -50,15 +50,16 @@ class AugmentImageComponent(MapDataComponent): ...@@ -50,15 +50,16 @@ class AugmentImageComponent(MapDataComponent):
""" """
Apply image augmentors on 1 component. Apply image augmentors on 1 component.
""" """
def __init__(self, ds, augmentors, index=0, copy=False): def __init__(self, ds, augmentors, index=0, copy=True):
""" """
Args: Args:
ds (DataFlow): input DataFlow. ds (DataFlow): input DataFlow.
augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` to be applied in order. augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` to be applied in order.
index (int): the index of the image component to be augmented. index (int): the index of the image component to be augmented.
copy (bool): make a copy of input images so the original data copy (bool): Some augmentors modify the input images. When copy is
won't be modified. Turn it on when it's dangerous to True, a copy will be made before any augmentors are applied,
modify the input (e.g. when inputs are persistent in memory). to keep the original images not modified.
Turn it off to save time when you know it's OK.
""" """
if isinstance(augmentors, AugmentorList): if isinstance(augmentors, AugmentorList):
self.augs = augmentors self.augs = augmentors
...@@ -94,29 +95,30 @@ class AugmentImageComponents(MapData): ...@@ -94,29 +95,30 @@ class AugmentImageComponents(MapData):
Apply image augmentors on several components, with shared augmentation parameters. Apply image augmentors on several components, with shared augmentation parameters.
""" """
def __init__(self, ds, augmentors, index=(0, 1), copy=False): def __init__(self, ds, augmentors, index=(0, 1), copy=True):
""" """
Args: Args:
ds (DataFlow): input DataFlow. ds (DataFlow): input DataFlow.
augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` instance to be applied in order. augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` instance to be applied in order.
index: tuple of indices of components. index: tuple of indices of components.
copy (bool): make a copy of input images so the original data copy (bool): Some augmentors modify the input images. When copy is
won't be modified. Turn it on when it's dangerous to True, a copy will be made before any augmentors are applied,
modify the input (e.g. when inputs are persistent in memory). to keep the original images not modified.
Turn it off to save time when you know it's OK.
""" """
self.augs = AugmentorList(augmentors) self.augs = AugmentorList(augmentors)
self.ds = ds self.ds = ds
self._nr_error = 0 self._nr_error = 0
def func(dp): def func(dp):
dp = copy_mod.copy(dp) # always do a shallow copy, make sure the list is intact
copy_func = copy_mod.deepcopy if copy else lambda x: x # noqa copy_func = copy_mod.deepcopy if copy else lambda x: x # noqa
dp = copy_func(dp)
try: try:
im = dp[index[0]] im = copy_func(dp[index[0]])
im, prms = self.augs._augment_return_params(im) im, prms = self.augs._augment_return_params(im)
dp[index[0]] = im dp[index[0]] = im
for idx in index[1:]: for idx in index[1:]:
dp[idx] = self.augs._augment(dp[idx], prms) dp[idx] = self.augs._augment(copy_func(dp[idx]), prms)
return dp return dp
except KeyboardInterrupt: except KeyboardInterrupt:
raise raise
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment