Commit a1da74af authored by Yuxin Wu's avatar Yuxin Wu

more docs about augmentor (#996)

parent 7a0b15d5
...@@ -123,7 +123,7 @@ _C.TRAIN.LR_SCHEDULE = [240000, 320000, 360000] # "2x" schedule in detectro ...@@ -123,7 +123,7 @@ _C.TRAIN.LR_SCHEDULE = [240000, 320000, 360000] # "2x" schedule in detectro
_C.TRAIN.EVAL_PERIOD = 25 # period (epochs) to run eva _C.TRAIN.EVAL_PERIOD = 25 # period (epochs) to run eva
# preprocessing -------------------- # preprocessing --------------------
# Alternative old (worse & faster) setting: 600, 1024 # Alternative old (worse & faster) setting: 600
_C.PREPROC.TRAIN_SHORT_EDGE_SIZE = [800, 800] # [min, max] to sample from _C.PREPROC.TRAIN_SHORT_EDGE_SIZE = [800, 800] # [min, max] to sample from
_C.PREPROC.TEST_SHORT_EDGE_SIZE = 800 _C.PREPROC.TEST_SHORT_EDGE_SIZE = 800
_C.PREPROC.MAX_SIZE = 1333 _C.PREPROC.MAX_SIZE = 1333
......
...@@ -205,7 +205,7 @@ def get_imagenet_tfdata(datadir, name, batch_size, mapper=None, parallel=None): ...@@ -205,7 +205,7 @@ def get_imagenet_tfdata(datadir, name, batch_size, mapper=None, parallel=None):
def fbresnet_mapper(isTrain): def fbresnet_mapper(isTrain):
""" """
Note: compared to fbresnet_augmentor, it Note: compared to fbresnet_augmentor, it
lacks some photometric augmentation that may have a small effect on accuracy. lacks some photometric augmentation that may have a small effect (0.1~0.2%) on accuracy.
""" """
JPEG_OPT = {'fancy_upscaling': True, 'dct_method': 'INTEGER_ACCURATE'} JPEG_OPT = {'fancy_upscaling': True, 'dct_method': 'INTEGER_ACCURATE'}
......
...@@ -15,9 +15,8 @@ __all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageCoordinates', ...@@ -15,9 +15,8 @@ __all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageCoordinates',
def check_dtype(img): def check_dtype(img):
assert isinstance(img, np.ndarray), "[Augmentor] Needs an numpy array, but got a {}!".format(type(img)) assert isinstance(img, np.ndarray), "[Augmentor] Needs an numpy array, but got a {}!".format(type(img))
if isinstance(img.dtype, np.integer): assert not isinstance(img.dtype, np.integer) or (img.dtype == np.uint8), \
assert img.dtype == np.uint8, \ "[Augmentor] Got image of type {}, use uint8 or floating points instead!".format(img.dtype)
"[Augmentor] Got image of type {}, use uint8 or floating points instead!".format(img.dtype)
def validate_coords(coords): def validate_coords(coords):
...@@ -161,9 +160,9 @@ class AugmentImageCoordinates(MapData): ...@@ -161,9 +160,9 @@ class AugmentImageCoordinates(MapData):
validate_coords(coords) validate_coords(coords)
if self._copy: if self._copy:
img, coords = copy_mod.deepcopy((img, coords)) img, coords = copy_mod.deepcopy((img, coords))
img, prms = self.augs._augment_return_params(img) img, prms = self.augs.augment_return_params(img)
dp[self._img_index] = img dp[self._img_index] = img
coords = self.augs._augment_coords(coords, prms) coords = self.augs.augment_coords(coords, prms)
dp[self._coords_index] = coords dp[self._coords_index] = coords
return dp return dp
...@@ -207,15 +206,15 @@ class AugmentImageComponents(MapData): ...@@ -207,15 +206,15 @@ class AugmentImageComponents(MapData):
major_image = index[0] # image to be used to get params. TODO better design? major_image = index[0] # image to be used to get params. TODO better design?
im = copy_func(dp[major_image]) im = copy_func(dp[major_image])
check_dtype(im) check_dtype(im)
im, prms = self.augs._augment_return_params(im) im, prms = self.augs.augment_return_params(im)
dp[major_image] = im dp[major_image] = im
for idx in index[1:]: for idx in index[1:]:
check_dtype(dp[idx]) check_dtype(dp[idx])
dp[idx] = self.augs._augment(copy_func(dp[idx]), prms) dp[idx] = self.augs.augment_with_params(copy_func(dp[idx]), prms)
for idx in coords_index: for idx in coords_index:
coords = copy_func(dp[idx]) coords = copy_func(dp[idx])
validate_coords(coords) validate_coords(coords)
dp[idx] = self.augs._augment_coords(coords, prms) dp[idx] = self.augs.augment_coords(coords, prms)
return dp return dp
super(AugmentImageComponents, self).__init__(ds, func) super(AugmentImageComponents, self).__init__(ds, func)
......
...@@ -35,15 +35,22 @@ class Augmentor(object): ...@@ -35,15 +35,22 @@ class Augmentor(object):
def augment(self, d): def augment(self, d):
""" """
Perform augmentation on the data. Perform augmentation on the data.
Returns:
augmented data
""" """
d, params = self._augment_return_params(d) d, params = self._augment_return_params(d)
return d return d
def augment_return_params(self, d): def augment_return_params(self, d):
""" """
Augment the data and return the augmentation parameters.
The returned parameters can be used to augment another data with identical transformation.
This can be used in, e.g. augmentation for image, masks, keypoints altogether.
Returns: Returns:
augmented data augmented data
augmentation params augmentation params: can be any type
""" """
return self._augment_return_params(d) return self._augment_return_params(d)
...@@ -54,6 +61,15 @@ class Augmentor(object): ...@@ -54,6 +61,15 @@ class Augmentor(object):
prms = self._get_augment_params(d) prms = self._get_augment_params(d)
return (self._augment(d, prms), prms) return (self._augment(d, prms), prms)
def augment_with_params(self, d, param):
"""
Augment the data with the given param.
Returns:
augmented data
"""
return self._augment(d, param)
@abstractmethod @abstractmethod
def _augment(self, d, param): def _augment(self, d, param):
""" """
...@@ -115,8 +131,9 @@ class ImageAugmentor(Augmentor): ...@@ -115,8 +131,9 @@ class ImageAugmentor(Augmentor):
def augment_coords(self, coords, param): def augment_coords(self, coords, param):
""" """
Augment the coordinates given the param. Augment the coordinates given the param.
By default, an augmentor keeps coordinates unchanged. By default, an augmentor keeps coordinates unchanged.
If a subclass changes coordinates but couldn't implement this method, If a subclass of :class:`ImageAugmentor` changes coordinates but couldn't implement this method,
it should ``raise NotImplementedError()``. it should ``raise NotImplementedError()``.
Args: Args:
...@@ -132,7 +149,7 @@ class ImageAugmentor(Augmentor): ...@@ -132,7 +149,7 @@ class ImageAugmentor(Augmentor):
class AugmentorList(ImageAugmentor): class AugmentorList(ImageAugmentor):
""" """
Augment by a list of augmentors Augment an image by a list of augmentors
""" """
def __init__(self, augmentors): def __init__(self, augmentors):
......
...@@ -378,7 +378,9 @@ class MultiThreadPrefetchData(DataFlow): ...@@ -378,7 +378,9 @@ class MultiThreadPrefetchData(DataFlow):
def __init__(self, get_df, nr_prefetch, nr_thread): def __init__(self, get_df, nr_prefetch, nr_thread):
""" """
Args: Args:
get_df ( -> DataFlow): a callable which returns a DataFlow get_df ( -> DataFlow): a callable which returns a DataFlow.
Each thread will call this function to get the DataFlow to use.
Therefore do not return the same DataFlow for each call.
nr_prefetch (int): size of the queue nr_prefetch (int): size of the queue
nr_thread (int): number of threads nr_thread (int): number of threads
""" """
......
...@@ -11,7 +11,6 @@ import zmq ...@@ -11,7 +11,6 @@ import zmq
from .base import DataFlow, ProxyDataFlow, DataFlowReentrantGuard from .base import DataFlow, ProxyDataFlow, DataFlowReentrantGuard
from .common import RepeatedData from .common import RepeatedData
from ..utils.concurrency import StoppableThread, enable_death_signal from ..utils.concurrency import StoppableThread, enable_death_signal
from ..utils import logger
from ..utils.serialize import loads, dumps from ..utils.serialize import loads, dumps
from .parallel import ( from .parallel import (
...@@ -59,10 +58,9 @@ class _ParallelMapData(ProxyDataFlow): ...@@ -59,10 +58,9 @@ class _ParallelMapData(ProxyDataFlow):
dp = next(self._iter) dp = next(self._iter)
self._send(dp) self._send(dp)
except StopIteration: except StopIteration:
logger.error( raise RuntimeError(
"[{}] buffer_size cannot be larger than the size of the DataFlow when strict=True!".format( "[{}] buffer_size cannot be larger than the size of the DataFlow when strict=True!".format(
type(self).__name__)) type(self).__name__))
raise
self._buffer_occupancy += cnt self._buffer_occupancy += cnt
def get_data_non_strict(self): def get_data_non_strict(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment