Commit a1da74af authored by Yuxin Wu's avatar Yuxin Wu

more docs about augmentor (#996)

parent 7a0b15d5
......@@ -123,7 +123,7 @@ _C.TRAIN.LR_SCHEDULE = [240000, 320000, 360000] # "2x" schedule in detectro
_C.TRAIN.EVAL_PERIOD = 25 # period (epochs) to run eva
# preprocessing --------------------
# Alternative old (worse & faster) setting: 600, 1024
# Alternative old (worse & faster) setting: 600
_C.PREPROC.TRAIN_SHORT_EDGE_SIZE = [800, 800] # [min, max] to sample from
_C.PREPROC.TEST_SHORT_EDGE_SIZE = 800
_C.PREPROC.MAX_SIZE = 1333
......
......@@ -205,7 +205,7 @@ def get_imagenet_tfdata(datadir, name, batch_size, mapper=None, parallel=None):
def fbresnet_mapper(isTrain):
"""
Note: compared to fbresnet_augmentor, it
lacks some photometric augmentation that may have a small effect on accuracy.
lacks some photometric augmentation that may have a small effect (0.1~0.2%) on accuracy.
"""
JPEG_OPT = {'fancy_upscaling': True, 'dct_method': 'INTEGER_ACCURATE'}
......
......@@ -15,9 +15,8 @@ __all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageCoordinates',
def check_dtype(img):
assert isinstance(img, np.ndarray), "[Augmentor] Needs an numpy array, but got a {}!".format(type(img))
if isinstance(img.dtype, np.integer):
assert img.dtype == np.uint8, \
"[Augmentor] Got image of type {}, use uint8 or floating points instead!".format(img.dtype)
assert not isinstance(img.dtype, np.integer) or (img.dtype == np.uint8), \
"[Augmentor] Got image of type {}, use uint8 or floating points instead!".format(img.dtype)
def validate_coords(coords):
......@@ -161,9 +160,9 @@ class AugmentImageCoordinates(MapData):
validate_coords(coords)
if self._copy:
img, coords = copy_mod.deepcopy((img, coords))
img, prms = self.augs._augment_return_params(img)
img, prms = self.augs.augment_return_params(img)
dp[self._img_index] = img
coords = self.augs._augment_coords(coords, prms)
coords = self.augs.augment_coords(coords, prms)
dp[self._coords_index] = coords
return dp
......@@ -207,15 +206,15 @@ class AugmentImageComponents(MapData):
major_image = index[0] # image to be used to get params. TODO better design?
im = copy_func(dp[major_image])
check_dtype(im)
im, prms = self.augs._augment_return_params(im)
im, prms = self.augs.augment_return_params(im)
dp[major_image] = im
for idx in index[1:]:
check_dtype(dp[idx])
dp[idx] = self.augs._augment(copy_func(dp[idx]), prms)
dp[idx] = self.augs.augment_with_params(copy_func(dp[idx]), prms)
for idx in coords_index:
coords = copy_func(dp[idx])
validate_coords(coords)
dp[idx] = self.augs._augment_coords(coords, prms)
dp[idx] = self.augs.augment_coords(coords, prms)
return dp
super(AugmentImageComponents, self).__init__(ds, func)
......
......@@ -35,15 +35,22 @@ class Augmentor(object):
def augment(self, d):
"""
Perform augmentation on the data.
Returns:
augmented data
"""
d, params = self._augment_return_params(d)
return d
def augment_return_params(self, d):
"""
Augment the data and return the augmentation parameters.
The returned parameters can be used to augment another data with identical transformation.
This can be used in, e.g. augmentation for image, masks, keypoints altogether.
Returns:
augmented data
augmentation params
augmentation params: can be any type
"""
return self._augment_return_params(d)
......@@ -54,6 +61,15 @@ class Augmentor(object):
prms = self._get_augment_params(d)
return (self._augment(d, prms), prms)
def augment_with_params(self, d, param):
"""
Augment the data with the given param.
Returns:
augmented data
"""
return self._augment(d, param)
@abstractmethod
def _augment(self, d, param):
"""
......@@ -115,8 +131,9 @@ class ImageAugmentor(Augmentor):
def augment_coords(self, coords, param):
"""
Augment the coordinates given the param.
By default, an augmentor keeps coordinates unchanged.
If a subclass changes coordinates but couldn't implement this method,
If a subclass of :class:`ImageAugmentor` changes coordinates but couldn't implement this method,
it should ``raise NotImplementedError()``.
Args:
......@@ -132,7 +149,7 @@ class ImageAugmentor(Augmentor):
class AugmentorList(ImageAugmentor):
"""
Augment by a list of augmentors
Augment an image by a list of augmentors
"""
def __init__(self, augmentors):
......
......@@ -378,7 +378,9 @@ class MultiThreadPrefetchData(DataFlow):
def __init__(self, get_df, nr_prefetch, nr_thread):
"""
Args:
get_df ( -> DataFlow): a callable which returns a DataFlow
get_df ( -> DataFlow): a callable which returns a DataFlow.
Each thread will call this function to get the DataFlow to use.
Therefore do not return the same DataFlow for each call.
nr_prefetch (int): size of the queue
nr_thread (int): number of threads
"""
......
......@@ -11,7 +11,6 @@ import zmq
from .base import DataFlow, ProxyDataFlow, DataFlowReentrantGuard
from .common import RepeatedData
from ..utils.concurrency import StoppableThread, enable_death_signal
from ..utils import logger
from ..utils.serialize import loads, dumps
from .parallel import (
......@@ -59,10 +58,9 @@ class _ParallelMapData(ProxyDataFlow):
dp = next(self._iter)
self._send(dp)
except StopIteration:
logger.error(
raise RuntimeError(
"[{}] buffer_size cannot be larger than the size of the DataFlow when strict=True!".format(
type(self).__name__))
raise
self._buffer_occupancy += cnt
def get_data_non_strict(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment