Commit 4f7c4682 authored by Yuxin Wu's avatar Yuxin Wu

some more docs

parent a6f88814
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
from pkgutil import walk_packages from pkgutil import walk_packages
import os import os
def global_import(name): def _global_import(name):
p = __import__(name, globals(), locals(), level=1) p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p) lst = p.__all__ if '__all__' in dir(p) else dir(p)
for k in lst: for k in lst:
...@@ -14,5 +14,5 @@ def global_import(name): ...@@ -14,5 +14,5 @@ def global_import(name):
for _, module_name, _ in walk_packages( for _, module_name, _ in walk_packages(
[os.path.dirname(__file__)]): [os.path.dirname(__file__)]):
if not module_name.startswith('_'): if not module_name.startswith('_'):
global_import(module_name) _global_import(module_name)
...@@ -73,7 +73,7 @@ class StatPrinter(Callback): ...@@ -73,7 +73,7 @@ class StatPrinter(Callback):
""" """
def __init__(self, print_tag=None): def __init__(self, print_tag=None):
""" """
:param print_tag : a list of regex to match scalar summary to print. :param print_tag: a list of regex to match scalar summary to print.
If None, will print all scalar tags If None, will print all scalar tags
""" """
self.print_tag = print_tag self.print_tag = print_tag
......
...@@ -9,7 +9,7 @@ import os.path ...@@ -9,7 +9,7 @@ import os.path
from . import dataset from . import dataset
from . import imgaug from . import imgaug
def global_import(name): def _global_import(name):
p = __import__(name, globals(), locals(), level=1) p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p) lst = p.__all__ if '__all__' in dir(p) else dir(p)
for k in lst: for k in lst:
...@@ -20,5 +20,5 @@ for _, module_name, _ in walk_packages( ...@@ -20,5 +20,5 @@ for _, module_name, _ in walk_packages(
[os.path.dirname(__file__)]): [os.path.dirname(__file__)]):
if not module_name.startswith('_') and \ if not module_name.startswith('_') and \
module_name not in __SKIP: module_name not in __SKIP:
global_import(module_name) _global_import(module_name)
...@@ -9,12 +9,13 @@ from abc import abstractmethod, ABCMeta ...@@ -9,12 +9,13 @@ from abc import abstractmethod, ABCMeta
__all__ = ['DataFlow', 'ProxyDataFlow'] __all__ = ['DataFlow', 'ProxyDataFlow']
class DataFlow(object): class DataFlow(object):
""" Base class for all DataFlow """
__metaclass__ = ABCMeta __metaclass__ = ABCMeta
@abstractmethod @abstractmethod
def get_data(self): def get_data(self):
""" """
A generator to generate data as tuple. A generator to generate data as a list.
""" """
def size(self): def size(self):
...@@ -30,10 +31,17 @@ class DataFlow(object): ...@@ -30,10 +31,17 @@ class DataFlow(object):
pass pass
class ProxyDataFlow(DataFlow): class ProxyDataFlow(DataFlow):
""" Base class for DataFlow that proxies another"""
def __init__(self, ds): def __init__(self, ds):
"""
:param ds: a :mod:`DataFlow` instance to proxy
"""
self.ds = ds self.ds = ds
def reset_state(self): def reset_state(self):
"""
Will reset state of the proxied DataFlow
"""
self.ds.reset_state() self.ds.reset_state()
def size(self): def size(self):
......
...@@ -14,10 +14,12 @@ __all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData', ...@@ -14,10 +14,12 @@ __all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData',
class BatchData(ProxyDataFlow): class BatchData(ProxyDataFlow):
def __init__(self, ds, batch_size, remainder=False): def __init__(self, ds, batch_size, remainder=False):
""" """
Group data in ds into batches Group data in `ds` into batches.
ds: a DataFlow instance
remainder: whether to return the remaining data smaller than a batch_size. :param ds: a DataFlow instance
if set True, will possibly return a data point of a smaller 1st dimension :param remainder: whether to return the remaining data smaller than a batch_size.
If set True, will possibly return a data point of a smaller 1st dimension.
Otherwise, all generated data are guranteed to have the same size.
""" """
super(BatchData, self).__init__(ds) super(BatchData, self).__init__(ds)
if not remainder: if not remainder:
...@@ -34,17 +36,20 @@ class BatchData(ProxyDataFlow): ...@@ -34,17 +36,20 @@ class BatchData(ProxyDataFlow):
return div + int(self.remainder) return div + int(self.remainder)
def get_data(self): def get_data(self):
"""
:returns: produce batched data by tiling data on an extra 0th dimension.
"""
holder = [] holder = []
for data in self.ds.get_data(): for data in self.ds.get_data():
holder.append(data) holder.append(data)
if len(holder) == self.batch_size: if len(holder) == self.batch_size:
yield BatchData.aggregate_batch(holder) yield BatchData._aggregate_batch(holder)
holder = [] holder = []
if self.remainder and len(holder) > 0: if self.remainder and len(holder) > 0:
yield BatchData.aggregate_batch(holder) yield BatchData._aggregate_batch(holder)
@staticmethod @staticmethod
def aggregate_batch(data_holder): def _aggregate_batch(data_holder):
size = len(data_holder[0]) size = len(data_holder[0])
result = [] result = []
for k in range(size): for k in range(size):
...@@ -60,16 +65,23 @@ class BatchData(ProxyDataFlow): ...@@ -60,16 +65,23 @@ class BatchData(ProxyDataFlow):
return result return result
class FixedSizeData(ProxyDataFlow): class FixedSizeData(ProxyDataFlow):
""" generate data from another dataflow, but with a fixed epoch size""" """ Generate data from another DataFlow, but with a fixed epoch size"""
def __init__(self, ds, size): def __init__(self, ds, size):
"""
:param ds: a :mod:`DataFlow` to produce data
:param size: a int
"""
super(FixedSizeData, self).__init__(ds) super(FixedSizeData, self).__init__(ds)
self._size = size self._size = int(size)
self.itr = None self.itr = None
def size(self): def size(self):
return self._size return self._size
def get_data(self): def get_data(self):
"""
Produce data from ds, stop at size
"""
if self.itr is None: if self.itr is None:
self.itr = self.ds.get_data() self.itr = self.ds.get_data()
cnt = 0 cnt = 0
...@@ -86,10 +98,15 @@ class FixedSizeData(ProxyDataFlow): ...@@ -86,10 +98,15 @@ class FixedSizeData(ProxyDataFlow):
return return
class RepeatedData(ProxyDataFlow): class RepeatedData(ProxyDataFlow):
""" repeat another dataflow for certain times """ Take data points from another `DataFlow` and produce them until
if nr == -1, repeat infinitely many times it's exhausted for certain amount of times.
""" """
def __init__(self, ds, nr): def __init__(self, ds, nr):
"""
:param ds: a :mod:`DataFlow` instance.
:param nr: number of times to repeat ds.
If nr == -1, repeat ds infinitely many times.
"""
self.nr = nr self.nr = nr
super(RepeatedData, self).__init__(ds) super(RepeatedData, self).__init__(ds)
...@@ -109,13 +126,14 @@ class RepeatedData(ProxyDataFlow): ...@@ -109,13 +126,14 @@ class RepeatedData(ProxyDataFlow):
yield dp yield dp
class FakeData(DataFlow): class FakeData(DataFlow):
""" Build fake random data of given shapes""" """ Generate fake random data of given shapes"""
def __init__(self, shapes, size): def __init__(self, shapes, size):
""" """
shapes: list of list/tuple :param shapes: a list of lists/tuples
:param size: size of this DataFlow
""" """
self.shapes = shapes self.shapes = shapes
self._size = size self._size = int(size)
self.rng = get_rng(self) self.rng = get_rng(self)
def size(self): def size(self):
...@@ -126,8 +144,12 @@ class FakeData(DataFlow): ...@@ -126,8 +144,12 @@ class FakeData(DataFlow):
yield [self.rng.random_sample(k) for k in self.shapes] yield [self.rng.random_sample(k) for k in self.shapes]
class MapData(ProxyDataFlow): class MapData(ProxyDataFlow):
""" Map a function to the datapoint""" """ Map a function on the datapoint"""
def __init__(self, ds, func): def __init__(self, ds, func):
"""
:param ds: a :mod:`DataFlow` instance.
:param func: a function that takes a original datapoint, returns a new datapoint
"""
super(MapData, self).__init__(ds) super(MapData, self).__init__(ds)
self.func = func self.func = func
...@@ -138,6 +160,10 @@ class MapData(ProxyDataFlow): ...@@ -138,6 +160,10 @@ class MapData(ProxyDataFlow):
class MapDataComponent(ProxyDataFlow): class MapDataComponent(ProxyDataFlow):
""" Apply a function to the given index in the datapoint""" """ Apply a function to the given index in the datapoint"""
def __init__(self, ds, func, index=0): def __init__(self, ds, func, index=0):
"""
:param ds: a :mod:`DataFlow` instance.
:param func: a function that takes a datapoint dp[index], returns a new value of dp[index]
"""
super(MapDataComponent, self).__init__(ds) super(MapDataComponent, self).__init__(ds)
self.func = func self.func = func
self.index = index self.index = index
...@@ -150,11 +176,12 @@ class MapDataComponent(ProxyDataFlow): ...@@ -150,11 +176,12 @@ class MapDataComponent(ProxyDataFlow):
class RandomChooseData(DataFlow): class RandomChooseData(DataFlow):
""" """
Randomly choose from several dataflow. Stop producing when any of its dataflow stops. Randomly choose from several DataFlow. Stop producing when any of them is
exhausted.
""" """
def __init__(self, df_lists): def __init__(self, df_lists):
""" """
df_lists: list of dataflow, or list of (dataflow, probability) tuple :param df_lists: list of dataflow, or list of (dataflow, probability) tuple
""" """
if isinstance(df_lists[0], (tuple, list)): if isinstance(df_lists[0], (tuple, list)):
assert sum([v[1] for v in df_lists]) == 1.0 assert sum([v[1] for v in df_lists]) == 1.0
...@@ -184,13 +211,12 @@ class RandomChooseData(DataFlow): ...@@ -184,13 +211,12 @@ class RandomChooseData(DataFlow):
class RandomMixData(DataFlow): class RandomMixData(DataFlow):
""" """
Randomly choose from several dataflow, will eventually exhaust all dataflow. Randomly choose from several dataflow, and will eventually exhaust all dataflow. So it's a perfect mix.
So it's a perfect mix.
""" """
def __init__(self, df_lists): def __init__(self, df_lists):
""" """
df_lists: list of dataflow :param df_lists: list of dataflow.
all DataFlow in df_lists must have size() implemented All DataFlow in `df_lists` must have :func:`size()` implemented
""" """
self.df_lists = df_lists self.df_lists = df_lists
self.sizes = [k.size() for k in self.df_lists] self.sizes = [k.size() for k in self.df_lists]
...@@ -217,11 +243,11 @@ class RandomMixData(DataFlow): ...@@ -217,11 +243,11 @@ class RandomMixData(DataFlow):
class JoinData(DataFlow): class JoinData(DataFlow):
""" """
Concatenate several dataflows Concatenate several dataflows.
""" """
def __init__(self, df_lists): def __init__(self, df_lists):
""" """
df_lists: list of dataflow :param df_lists: list of :mod:`DataFlow` instances
""" """
self.df_lists = df_lists self.df_lists = df_lists
......
...@@ -9,8 +9,12 @@ from ..utils.fs import mkdir_p ...@@ -9,8 +9,12 @@ from ..utils.fs import mkdir_p
# TODO name_func to write label? # TODO name_func to write label?
def dump_dataset_images(ds, dirname, max_count=None, index=0): def dump_dataset_images(ds, dirname, max_count=None, index=0):
""" dump images to a folder """ Dump images from a `DataFlow` to a directory.
index: the index of the image in a data point
:param ds: a `DataFlow` instance.
:param dirname: name of the directory.
:param max_count: max number of images to dump
:param index: the index of the image component in a data point.
""" """
mkdir_p(dirname) mkdir_p(dirname)
if max_count is None: if max_count is None:
......
...@@ -22,9 +22,14 @@ __all__ = ['HDF5Data'] ...@@ -22,9 +22,14 @@ __all__ = ['HDF5Data']
class HDF5Data(DataFlow): class HDF5Data(DataFlow):
""" """
Zip data from different paths in this HDF5 data file Zip data from different paths in an HDF5 file. Will load all data into memory.
""" """
def __init__(self, filename, data_paths, shuffle=True): def __init__(self, filename, data_paths, shuffle=True):
"""
:param filename: h5 data file.
:param data_paths: list of h5 paths to zipped. For example ['images', 'labels']
:param shuffle: shuffle the order of all data.
"""
self.f = h5py.File(filename, 'r') self.f = h5py.File(filename, 'r')
logger.info("Loading {} to memory...".format(filename)) logger.info("Loading {} to memory...".format(filename))
self.dps = [self.f[k].value for k in data_paths] self.dps = [self.f[k].value for k in data_paths]
......
...@@ -12,11 +12,11 @@ from .imgaug import AugmentorList, Image ...@@ -12,11 +12,11 @@ from .imgaug import AugmentorList, Image
__all__ = ['ImageFromFile', 'AugmentImageComponent'] __all__ = ['ImageFromFile', 'AugmentImageComponent']
class ImageFromFile(DataFlow): class ImageFromFile(DataFlow):
""" generate rgb images from files """ """ Generate rgb images from list of files """
def __init__(self, files, channel=3, resize=None): def __init__(self, files, channel=3, resize=None):
""" files: list of file path """ :param files: list of file paths
channel: 1 or 3 channel :param channel: 1 or 3 channel
resize: a (h, w) tuple. If given, will force a resize :param resize: a (h, w) tuple. If given, will force a resize
""" """
assert len(files) assert len(files)
self.files = files self.files = files
...@@ -39,13 +39,14 @@ class ImageFromFile(DataFlow): ...@@ -39,13 +39,14 @@ class ImageFromFile(DataFlow):
class AugmentImageComponent(MapDataComponent): class AugmentImageComponent(MapDataComponent):
""" """
Augment image in each data point Augment the image component of datapoints
Args:
ds: a DataFlow dataset instance
augmentors: a list of ImageAugmentor instance
index: the index of image in each data point. default to be 0
""" """
def __init__(self, ds, augmentors, index=0): def __init__(self, ds, augmentors, index=0):
"""
:param ds: a `DataFlow` instance.
:param augmentors: a list of `ImageAugmentor` instance to be applied in order.
:param index: the index of the image component in the produced datapoints by `ds`. default to be 0
"""
self.augs = AugmentorList(augmentors) self.augs = AugmentorList(augmentors)
super(AugmentImageComponent, self).__init__( super(AugmentImageComponent, self).__init__(
ds, lambda x: self.augs.augment(Image(x)).arr, index) ds, lambda x: self.augs.augment(Image(x)).arr, index)
......
...@@ -8,14 +8,20 @@ from ...utils import get_rng ...@@ -8,14 +8,20 @@ from ...utils import get_rng
__all__ = ['Image', 'ImageAugmentor', 'AugmentorList'] __all__ = ['Image', 'ImageAugmentor', 'AugmentorList']
class Image(object): class Image(object):
""" An image with attributes, for augmentor to operate on """ An image class with attributes, for augmentor to operate on.
Attributes (such as coordinates) have to be augmented acoordingly, if necessary Attributes (such as coordinates) have to be augmented acoordingly by
the augmentor, if necessary.
""" """
def __init__(self, arr, coords=None): def __init__(self, arr, coords=None):
"""
:param arr: the image array. Expected to be of [h, w, c] or [h, w]
:param coords: keypoint coordinates.
"""
self.arr = arr self.arr = arr
self.coords = coords self.coords = coords
class ImageAugmentor(object): class ImageAugmentor(object):
""" Base class for an image augmentor"""
__metaclass__ = ABCMeta __metaclass__ = ABCMeta
def __init__(self): def __init__(self):
...@@ -33,19 +39,17 @@ class ImageAugmentor(object): ...@@ -33,19 +39,17 @@ class ImageAugmentor(object):
def augment(self, img): def augment(self, img):
""" """
Note: will both modify `img` in-place and return `img` Perform augmentation on the image in-place.
:param img: an `Image` instance.
:returns: the augmented `Image` instance. arr will always be of type
'float32' after augmentation.
""" """
self._augment(img) self._augment(img)
return img return img
@abstractmethod @abstractmethod
def _augment(self, img): def _augment(self, img):
""" pass
Augment the image in-place. Will always make it float32 array.
Args:
img: the input Image instance
img.arr must be of shape [h, w] or [h, w, c]
"""
def _rand_range(self, low=1.0, high=None, size=None): def _rand_range(self, low=1.0, high=None, size=None):
if high is None: if high is None:
...@@ -59,6 +63,9 @@ class AugmentorList(ImageAugmentor): ...@@ -59,6 +63,9 @@ class AugmentorList(ImageAugmentor):
Augment by a list of augmentors Augment by a list of augmentors
""" """
def __init__(self, augmentors): def __init__(self, augmentors):
"""
:param augmentors: list of `ImageAugmentor` instance to be applied
"""
self.augs = augmentors self.augs = augmentors
def _augment(self, img): def _augment(self, img):
...@@ -68,5 +75,6 @@ class AugmentorList(ImageAugmentor): ...@@ -68,5 +75,6 @@ class AugmentorList(ImageAugmentor):
aug.augment(img) aug.augment(img)
def reset_state(self): def reset_state(self):
""" Will reset state of each augmentor """
for a in self.augs: for a in self.augs:
a.reset_state() a.reset_state()
...@@ -14,8 +14,7 @@ class RandomCrop(ImageAugmentor): ...@@ -14,8 +14,7 @@ class RandomCrop(ImageAugmentor):
""" Randomly crop the image into a smaller one """ """ Randomly crop the image into a smaller one """
def __init__(self, crop_shape): def __init__(self, crop_shape):
""" """
Args: :param crop_shape: a shape like (h, w)
crop_shape: shape in (h, w)
""" """
self._init(locals()) self._init(locals())
...@@ -28,8 +27,11 @@ class RandomCrop(ImageAugmentor): ...@@ -28,8 +27,11 @@ class RandomCrop(ImageAugmentor):
raise NotImplementedError() raise NotImplementedError()
class CenterCrop(ImageAugmentor): class CenterCrop(ImageAugmentor):
""" Crop the image in the center""" """ Crop the image at the center"""
def __init__(self, crop_shape): def __init__(self, crop_shape):
"""
:param crop_shape: a shape like (h, w)
"""
self._init(locals()) self._init(locals())
def _augment(self, img): def _augment(self, img):
...@@ -43,6 +45,12 @@ class CenterCrop(ImageAugmentor): ...@@ -43,6 +45,12 @@ class CenterCrop(ImageAugmentor):
class FixedCrop(ImageAugmentor): class FixedCrop(ImageAugmentor):
""" Crop a rectangle at a given location""" """ Crop a rectangle at a given location"""
def __init__(self, rangex, rangey): def __init__(self, rangex, rangey):
"""
Two arguments defined the range in both axes to crop, min inclued, max excluded.
:param rangex: like (xmin, xmax).
:param rangey: like (ymin, ymax).
"""
self._init(locals()) self._init(locals())
def _augment(self, img): def _augment(self, img):
...@@ -53,17 +61,30 @@ class FixedCrop(ImageAugmentor): ...@@ -53,17 +61,30 @@ class FixedCrop(ImageAugmentor):
raise NotImplementedError() raise NotImplementedError()
class BackgroundFiller(object): class BackgroundFiller(object):
@abstractmethod """ Base class for all BackgroundFiller"""
def fill(background_shape, img): def fill(self, background_shape, img):
""" """
return a proper background image of background_shape, given img Return a proper background image of background_shape, given img
:param background_shape: a shape of [h, w]
:param img: an image
:returns: a background image
""" """
return self._fill(background_shape, img)
@abstractmethod
def _fill(self, background_shape, img):
pass
class ConstantBackgroundFiller(BackgroundFiller): class ConstantBackgroundFiller(BackgroundFiller):
""" Fill the background by a constant """
def __init__(self, value): def __init__(self, value):
"""
:param value: the value to fill the background.
"""
self.value = value self.value = value
def fill(self, background_shape, img): def _fill(self, background_shape, img):
assert img.ndim in [3, 1] assert img.ndim in [3, 1]
if img.ndim == 3: if img.ndim == 3:
return_shape = background_shape + (3,) return_shape = background_shape + (3,)
...@@ -73,9 +94,13 @@ class ConstantBackgroundFiller(BackgroundFiller): ...@@ -73,9 +94,13 @@ class ConstantBackgroundFiller(BackgroundFiller):
class CenterPaste(ImageAugmentor): class CenterPaste(ImageAugmentor):
""" """
Paste the image onto center of a background Paste the image onto the center of a background canvas.
""" """
def __init__(self, background_shape, background_filler=None): def __init__(self, background_shape, background_filler=None):
"""
:param background_shape: shape of the background canvas.
:param background_filler: a `BackgroundFiller` instance. Default to zero-filler.
"""
if background_filler is None: if background_filler is None:
background_filler = ConstantBackgroundFiller(0) background_filler = ConstantBackgroundFiller(0)
......
...@@ -10,6 +10,7 @@ __all__ = ['GaussianDeform', 'GaussianMap'] ...@@ -10,6 +10,7 @@ __all__ = ['GaussianDeform', 'GaussianMap']
# TODO really needs speedup # TODO really needs speedup
class GaussianMap(object): class GaussianMap(object):
""" Generate gaussian weighted deformation map"""
def __init__(self, image_shape, sigma=0.5): def __init__(self, image_shape, sigma=0.5):
assert len(image_shape) == 2 assert len(image_shape) == 2
self.shape = image_shape self.shape = image_shape
...@@ -53,14 +54,14 @@ def np_sample(img, coords): ...@@ -53,14 +54,14 @@ def np_sample(img, coords):
# TODO input/output with different shape # TODO input/output with different shape
class GaussianDeform(ImageAugmentor): class GaussianDeform(ImageAugmentor):
""" """
Some kind of deformation Some kind of deformation. Quite slow.
""" """
#TODO docs
def __init__(self, anchors, shape, sigma=0.5, randrange=None): def __init__(self, anchors, shape, sigma=0.5, randrange=None):
""" """
anchors: in [0,1] coordinate :param anchors: in [0,1] coordinate
shape: 2D image shape :param shape: image shape in [h, w]
randrange: default to shape[0] / 8 :param sigma: sigma for Gaussian weight
:param randrange: default to shape[0] / 8
""" """
super(GaussianDeform, self).__init__() super(GaussianDeform, self).__init__()
self.anchors = anchors self.anchors = anchors
......
...@@ -9,9 +9,12 @@ __all__ = ['BrightnessAdd', 'Contrast', 'MeanVarianceNormalize'] ...@@ -9,9 +9,12 @@ __all__ = ['BrightnessAdd', 'Contrast', 'MeanVarianceNormalize']
class BrightnessAdd(ImageAugmentor): class BrightnessAdd(ImageAugmentor):
""" """
Randomly add a value within [-delta,delta], and clip in [0,255] Random adjust brightness.
""" """
def __init__(self, delta, clip=True): def __init__(self, delta, clip=True):
"""
Randomly add a value within [-delta,delta], and clip in [0,255] if clip is True.
"""
assert delta > 0 assert delta > 0
self._init(locals()) self._init(locals())
...@@ -27,6 +30,10 @@ class Contrast(ImageAugmentor): ...@@ -27,6 +30,10 @@ class Contrast(ImageAugmentor):
and clip to [0, 255] and clip to [0, 255]
""" """
def __init__(self, factor_range, clip=True): def __init__(self, factor_range, clip=True):
"""
:param factor_range: an interval to random sample the `contrast_factor`.
:param clip: boolean.
"""
self._init(locals()) self._init(locals())
def _augment(self, img): def _augment(self, img):
...@@ -44,6 +51,9 @@ class MeanVarianceNormalize(ImageAugmentor): ...@@ -44,6 +51,9 @@ class MeanVarianceNormalize(ImageAugmentor):
where adjusted_stddev = max(stddev, 1.0/sqrt(num_pixels * channels)) where adjusted_stddev = max(stddev, 1.0/sqrt(num_pixels * channels))
""" """
def __init__(self, all_channel=True): def __init__(self, all_channel=True):
"""
:param all_channel: if True, normalize all channels together. else separately.
"""
self.all_channel = all_channel self.all_channel = all_channel
def _augment(self, img): def _augment(self, img):
......
...@@ -9,11 +9,16 @@ import cv2 ...@@ -9,11 +9,16 @@ import cv2
__all__ = ['Flip', 'MapImage', 'Resize'] __all__ = ['Flip', 'MapImage', 'Resize']
class Flip(ImageAugmentor): class Flip(ImageAugmentor):
"""
Random flip.
"""
def __init__(self, horiz=False, vert=False, prob=0.5): def __init__(self, horiz=False, vert=False, prob=0.5):
""" """
Random flip. Only one of horiz, vert can be set.
Args:
horiz, vert: True/False :param horiz: whether or not apply horizontal flip.
:param vert: whether or not apply vertical flip.
:param prob: probability of flip.
""" """
if horiz and vert: if horiz and vert:
raise ValueError("Please use two Flip, with both 0.5 prob") raise ValueError("Please use two Flip, with both 0.5 prob")
...@@ -34,7 +39,13 @@ class Flip(ImageAugmentor): ...@@ -34,7 +39,13 @@ class Flip(ImageAugmentor):
class MapImage(ImageAugmentor): class MapImage(ImageAugmentor):
"""
Map the image array by a function.
"""
def __init__(self, func): def __init__(self, func):
"""
:param func: a function which takes a image array and return a augmented one
"""
self.func = func self.func = func
def _augment(self, img): def _augment(self, img):
...@@ -42,11 +53,10 @@ class MapImage(ImageAugmentor): ...@@ -42,11 +53,10 @@ class MapImage(ImageAugmentor):
class Resize(ImageAugmentor): class Resize(ImageAugmentor):
"""Resize image to a target size""" """ Resize image to a target size"""
def __init__(self, shape): def __init__(self, shape):
""" """
Args: :param shape: shape in (h, w)
shape: (h, w)
""" """
self._init(locals()) self._init(locals())
......
...@@ -15,8 +15,8 @@ class Sentinel: ...@@ -15,8 +15,8 @@ class Sentinel:
class PrefetchProcess(multiprocessing.Process): class PrefetchProcess(multiprocessing.Process):
def __init__(self, ds, queue): def __init__(self, ds, queue):
""" """
ds: ds to take data from :param ds: ds to take data from
queue: output queue to put results in :param queue: output queue to put results in
""" """
super(PrefetchProcess, self).__init__() super(PrefetchProcess, self).__init__()
self.ds = ds self.ds = ds
...@@ -30,11 +30,15 @@ class PrefetchProcess(multiprocessing.Process): ...@@ -30,11 +30,15 @@ class PrefetchProcess(multiprocessing.Process):
finally: finally:
self.queue.put(Sentinel()) self.queue.put(Sentinel())
class PrefetchData(DataFlow): class PrefetchData(DataFlow):
"""
Prefetch data from a `DataFlow` using multiprocessing
"""
def __init__(self, ds, nr_prefetch, nr_proc=1): def __init__(self, ds, nr_prefetch, nr_proc=1):
""" """
use multiprocess :param ds: a `DataFlow` instance.
:param nr_prefetch: size of the queue to hold prefetched datapoints.
:param nr_proc: number of processes to use.
""" """
self.ds = ds self.ds = ds
self._size = self.ds.size() self._size = self.ds.size()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment