Commit 4f7c4682 authored by Yuxin Wu

some more docs

parent a6f88814
......@@ -5,7 +5,7 @@
from pkgutil import walk_packages
import os
def global_import(name):
def _global_import(name):
p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p)
for k in lst:
......@@ -14,5 +14,5 @@ def global_import(name):
for _, module_name, _ in walk_packages(
[os.path.dirname(__file__)]):
if not module_name.startswith('_'):
global_import(module_name)
_global_import(module_name)
......@@ -73,7 +73,7 @@ class StatPrinter(Callback):
"""
def __init__(self, print_tag=None):
"""
:param print_tag : a list of regex to match scalar summary to print.
:param print_tag: a list of regex to match scalar summary to print.
If None, will print all scalar tags
"""
self.print_tag = print_tag
......
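A minimal usage sketch for the clarified `print_tag` parameter; only the constructor signature appears in this diff, so the import path and regex values below are assumptions.

from tensorpack.callbacks import StatPrinter   # assumed import path

# Print only scalar summaries whose tags match these regexes;
# with print_tag=None, every scalar tag would be printed.
stat_cb = StatPrinter(print_tag=['train_error', 'val_.*'])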
......@@ -9,7 +9,7 @@ import os.path
from . import dataset
from . import imgaug
def global_import(name):
def _global_import(name):
p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p)
for k in lst:
......@@ -20,5 +20,5 @@ for _, module_name, _ in walk_packages(
[os.path.dirname(__file__)]):
if not module_name.startswith('_') and \
module_name not in __SKIP:
global_import(module_name)
_global_import(module_name)
......@@ -9,12 +9,13 @@ from abc import abstractmethod, ABCMeta
__all__ = ['DataFlow', 'ProxyDataFlow']
class DataFlow(object):
""" Base class for all DataFlow """
__metaclass__ = ABCMeta
@abstractmethod
def get_data(self):
"""
A generator to generate data as tuple.
A generator to generate data as a list.
"""
def size(self):
......@@ -30,10 +31,17 @@ class DataFlow(object):
pass
class ProxyDataFlow(DataFlow):
""" Base class for DataFlow that proxies another"""
def __init__(self, ds):
"""
:param ds: a :mod:`DataFlow` instance to proxy
"""
self.ds = ds
def reset_state(self):
"""
Will reset state of the proxied DataFlow
"""
self.ds.reset_state()
def size(self):
......
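To illustrate the contract documented above (get_data yields each datapoint as a list, size() is optional), here is a minimal sketch of a custom DataFlow; the import path and class name are illustrative, not part of the commit.

from tensorpack.dataflow import DataFlow   # assumed import path

class CounterFlow(DataFlow):
    """ Illustrative: yields the integers 0..n-1, one per datapoint. """
    def __init__(self, n):
        self.n = n
    def size(self):
        return self.n
    def get_data(self):
        for i in range(self.n):
            yield [i]   # a datapoint is a list of components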
......@@ -14,10 +14,12 @@ __all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData',
class BatchData(ProxyDataFlow):
def __init__(self, ds, batch_size, remainder=False):
"""
Group data in ds into batches
ds: a DataFlow instance
remainder: whether to return the remaining data smaller than a batch_size.
if set True, will possibly return a data point of a smaller 1st dimension
Group data in `ds` into batches.
:param ds: a DataFlow instance
:param remainder: whether to return the remaining data that is smaller than batch_size.
If set True, it will possibly return a data point of a smaller 1st dimension.
Otherwise, all generated data are guaranteed to have the same size.
"""
super(BatchData, self).__init__(ds)
if not remainder:
......@@ -34,17 +36,20 @@ class BatchData(ProxyDataFlow):
return div + int(self.remainder)
def get_data(self):
"""
:returns: produce batched data by tiling data on an extra 0th dimension.
"""
holder = []
for data in self.ds.get_data():
holder.append(data)
if len(holder) == self.batch_size:
yield BatchData.aggregate_batch(holder)
yield BatchData._aggregate_batch(holder)
holder = []
if self.remainder and len(holder) > 0:
yield BatchData.aggregate_batch(holder)
yield BatchData._aggregate_batch(holder)
@staticmethod
def aggregate_batch(data_holder):
def _aggregate_batch(data_holder):
size = len(data_holder[0])
result = []
for k in range(size):
......@@ -60,16 +65,23 @@ class BatchData(ProxyDataFlow):
return result
class FixedSizeData(ProxyDataFlow):
""" generate data from another dataflow, but with a fixed epoch size"""
""" Generate data from another DataFlow, but with a fixed epoch size"""
def __init__(self, ds, size):
"""
:param ds: a :mod:`DataFlow` to produce data
:param size: an int
"""
super(FixedSizeData, self).__init__(ds)
self._size = size
self._size = int(size)
self.itr = None
def size(self):
return self._size
def get_data(self):
"""
Produce data from `ds`, stopping at `size`
"""
if self.itr is None:
self.itr = self.ds.get_data()
cnt = 0
......@@ -86,10 +98,15 @@ class FixedSizeData(ProxyDataFlow):
return
class RepeatedData(ProxyDataFlow):
""" repeat another dataflow for certain times
if nr == -1, repeat infinitely many times
""" Take data points from another `DataFlow` and produce them until
it's exhausted for certain amount of times.
"""
def __init__(self, ds, nr):
"""
:param ds: a :mod:`DataFlow` instance.
:param nr: number of times to repeat ds.
If nr == -1, repeat ds infinitely many times.
"""
self.nr = nr
super(RepeatedData, self).__init__(ds)
......@@ -109,13 +126,14 @@ class RepeatedData(ProxyDataFlow):
yield dp
class FakeData(DataFlow):
""" Build fake random data of given shapes"""
""" Generate fake random data of given shapes"""
def __init__(self, shapes, size):
"""
shapes: list of list/tuple
:param shapes: a list of lists/tuples
:param size: size of this DataFlow
"""
self.shapes = shapes
self._size = size
self._size = int(size)
self.rng = get_rng(self)
def size(self):
......@@ -126,8 +144,12 @@ class FakeData(DataFlow):
yield [self.rng.random_sample(k) for k in self.shapes]
class MapData(ProxyDataFlow):
""" Map a function to the datapoint"""
""" Map a function on the datapoint"""
def __init__(self, ds, func):
"""
:param ds: a :mod:`DataFlow` instance.
:param func: a function that takes an original datapoint and returns a new datapoint
"""
super(MapData, self).__init__(ds)
self.func = func
......@@ -138,6 +160,10 @@ class MapData(ProxyDataFlow):
class MapDataComponent(ProxyDataFlow):
""" Apply a function to the given index in the datapoint"""
def __init__(self, ds, func, index=0):
"""
:param ds: a :mod:`DataFlow` instance.
:param func: a function that takes dp[index] of a datapoint and returns a new value for dp[index]
"""
super(MapDataComponent, self).__init__(ds)
self.func = func
self.index = index
......@@ -150,11 +176,12 @@ class MapDataComponent(ProxyDataFlow):
class RandomChooseData(DataFlow):
"""
Randomly choose from several dataflow. Stop producing when any of its dataflow stops.
Randomly choose from several DataFlow. Stop producing when any of them is
exhausted.
"""
def __init__(self, df_lists):
"""
df_lists: list of dataflow, or list of (dataflow, probability) tuple
:param df_lists: list of dataflow, or list of (dataflow, probability) tuple
"""
if isinstance(df_lists[0], (tuple, list)):
assert sum([v[1] for v in df_lists]) == 1.0
......@@ -184,13 +211,12 @@ class RandomChooseData(DataFlow):
class RandomMixData(DataFlow):
"""
Randomly choose from several dataflow, will eventually exhaust all dataflow.
So it's a perfect mix.
Randomly choose from several dataflow, and will eventually exhaust all dataflow. So it's a perfect mix.
"""
def __init__(self, df_lists):
"""
df_lists: list of dataflow
all DataFlow in df_lists must have size() implemented
:param df_lists: list of dataflow.
All DataFlow in `df_lists` must have :func:`size()` implemented
"""
self.df_lists = df_lists
self.sizes = [k.size() for k in self.df_lists]
......@@ -217,11 +243,11 @@ class RandomMixData(DataFlow):
class JoinData(DataFlow):
"""
Concatenate several dataflows
Concatenate several dataflows.
"""
def __init__(self, df_lists):
"""
df_lists: list of dataflow
:param df_lists: list of :mod:`DataFlow` instances
"""
self.df_lists = df_lists
......
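A hedged sketch of how the dataflows documented in this file compose; the package-level import path, shapes, and batch size are assumptions, not part of the diff.

from tensorpack.dataflow import FakeData, MapDataComponent, BatchData, RepeatedData  # assumed path

ds = FakeData([(28, 28, 3), (1,)], size=100)        # random "images" and "labels"
ds = MapDataComponent(ds, lambda img: img * 2.0)    # transform component 0 only
ds = BatchData(ds, batch_size=32, remainder=True)   # last batch may be smaller
ds = RepeatedData(ds, nr=-1)                        # repeat indefinitely
for dp in ds.get_data():                            # dp == [batched images, batched labels]
    break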
......@@ -9,8 +9,12 @@ from ..utils.fs import mkdir_p
# TODO name_func to write label?
def dump_dataset_images(ds, dirname, max_count=None, index=0):
""" dump images to a folder
index: the index of the image in a data point
""" Dump images from a `DataFlow` to a directory.
:param ds: a `DataFlow` instance.
:param dirname: name of the directory.
:param max_count: max number of images to dump
:param index: the index of the image component in a data point.
"""
mkdir_p(dirname)
if max_count is None:
......
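A hypothetical call matching the signature documented above; the module path, source dataflow, and directory are placeholders.

from tensorpack.dataflow import FakeData
from tensorpack.dataflow.dftools import dump_dataset_images   # assumed module path

ds = FakeData([(64, 64, 3)], size=20)                 # any DataFlow whose component 0 is an image
dump_dataset_images(ds, '/tmp/dumped', max_count=10, index=0)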
......@@ -22,9 +22,14 @@ __all__ = ['HDF5Data']
class HDF5Data(DataFlow):
"""
Zip data from different paths in this HDF5 data file
Zip data from different paths in an HDF5 file. Will load all data into memory.
"""
def __init__(self, filename, data_paths, shuffle=True):
"""
:param filename: h5 data file.
:param data_paths: list of h5 paths to zip. For example ['images', 'labels']
:param shuffle: shuffle the order of all data.
"""
self.f = h5py.File(filename, 'r')
logger.info("Loading {} to memory...".format(filename))
self.dps = [self.f[k].value for k in data_paths]
......
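A hedged example for the HDF5Data docstring above; the file name and dataset paths are illustrative and the import path is assumed.

from tensorpack.dataflow import HDF5Data   # assumed import path

ds = HDF5Data('train.h5', data_paths=['images', 'labels'], shuffle=True)
# each datapoint zips one element from every path: [image, label]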
......@@ -12,11 +12,11 @@ from .imgaug import AugmentorList, Image
__all__ = ['ImageFromFile', 'AugmentImageComponent']
class ImageFromFile(DataFlow):
""" generate rgb images from files """
""" Generate rgb images from list of files """
def __init__(self, files, channel=3, resize=None):
""" files: list of file path
channel: 1 or 3 channel
resize: a (h, w) tuple. If given, will force a resize
""" :param files: list of file paths
:param channel: 1 or 3 channel
:param resize: a (h, w) tuple. If given, will force a resize
"""
assert len(files)
self.files = files
......@@ -39,13 +39,14 @@ class ImageFromFile(DataFlow):
class AugmentImageComponent(MapDataComponent):
"""
Augment image in each data point
Args:
ds: a DataFlow dataset instance
augmentors: a list of ImageAugmentor instance
index: the index of image in each data point. default to be 0
Augment the image component of datapoints
"""
def __init__(self, ds, augmentors, index=0):
"""
:param ds: a `DataFlow` instance.
:param augmentors: a list of `ImageAugmentor` instances to be applied in order.
:param index: the index of the image component in the datapoints produced by `ds`. Defaults to 0.
"""
self.augs = AugmentorList(augmentors)
super(AugmentImageComponent, self).__init__(
ds, lambda x: self.augs.augment(Image(x)).arr, index)
......
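A sketch combining the two classes documented above; the file names, augmentor choices, and import paths are assumptions.

from tensorpack.dataflow import ImageFromFile, AugmentImageComponent, imgaug   # assumed path

ds = ImageFromFile(['a.jpg', 'b.jpg'], channel=3, resize=(256, 256))
ds = AugmentImageComponent(ds, [imgaug.Flip(horiz=True), imgaug.CenterCrop((224, 224))], index=0)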
......@@ -8,14 +8,20 @@ from ...utils import get_rng
__all__ = ['Image', 'ImageAugmentor', 'AugmentorList']
class Image(object):
""" An image with attributes, for augmentor to operate on
Attributes (such as coordinates) have to be augmented acoordingly, if necessary
""" An image class with attributes, for augmentor to operate on.
Attributes (such as coordinates) have to be augmented accordingly by
the augmentor, if necessary.
"""
def __init__(self, arr, coords=None):
"""
:param arr: the image array. Expected to be of [h, w, c] or [h, w]
:param coords: keypoint coordinates.
"""
self.arr = arr
self.coords = coords
class ImageAugmentor(object):
""" Base class for an image augmentor"""
__metaclass__ = ABCMeta
def __init__(self):
......@@ -33,19 +39,17 @@ class ImageAugmentor(object):
def augment(self, img):
"""
Note: will both modify `img` in-place and return `img`
Perform augmentation on the image in-place.
:param img: an `Image` instance.
:returns: the augmented `Image` instance. arr will always be of type
'float32' after augmentation.
"""
self._augment(img)
return img
@abstractmethod
def _augment(self, img):
"""
Augment the image in-place. Will always make it float32 array.
Args:
img: the input Image instance
img.arr must be of shape [h, w] or [h, w, c]
"""
pass
def _rand_range(self, low=1.0, high=None, size=None):
if high is None:
......@@ -59,6 +63,9 @@ class AugmentorList(ImageAugmentor):
Augment by a list of augmentors
"""
def __init__(self, augmentors):
"""
:param augmentors: list of `ImageAugmentor` instances to be applied
"""
self.augs = augmentors
def _augment(self, img):
......@@ -68,5 +75,6 @@ class AugmentorList(ImageAugmentor):
aug.augment(img)
def reset_state(self):
""" Will reset state of each augmentor """
for a in self.augs:
a.reset_state()
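Given the augment/_augment split documented above, a custom augmentor only needs to override `_augment`; this is an illustrative sketch with an assumed import path, not code from the commit.

from tensorpack.dataflow.imgaug import ImageAugmentor   # assumed import path

class MeanSubtract(ImageAugmentor):
    """ Illustrative: subtract the per-image mean, in place. """
    def _augment(self, img):
        img.arr = img.arr - img.arr.mean()   # img is an Image; modify img.arr in place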
......@@ -14,8 +14,7 @@ class RandomCrop(ImageAugmentor):
""" Randomly crop the image into a smaller one """
def __init__(self, crop_shape):
"""
Args:
crop_shape: shape in (h, w)
:param crop_shape: a shape like (h, w)
"""
self._init(locals())
......@@ -28,8 +27,11 @@ class RandomCrop(ImageAugmentor):
raise NotImplementedError()
class CenterCrop(ImageAugmentor):
""" Crop the image in the center"""
""" Crop the image at the center"""
def __init__(self, crop_shape):
"""
:param crop_shape: a shape like (h, w)
"""
self._init(locals())
def _augment(self, img):
......@@ -43,6 +45,12 @@ class CenterCrop(ImageAugmentor):
class FixedCrop(ImageAugmentor):
""" Crop a rectangle at a given location"""
def __init__(self, rangex, rangey):
"""
Two arguments define the range in both axes to crop, min included, max excluded.
:param rangex: like (xmin, xmax).
:param rangey: like (ymin, ymax).
"""
self._init(locals())
def _augment(self, img):
......@@ -53,17 +61,30 @@ class FixedCrop(ImageAugmentor):
raise NotImplementedError()
class BackgroundFiller(object):
@abstractmethod
def fill(background_shape, img):
""" Base class for all BackgroundFiller"""
def fill(self, background_shape, img):
"""
return a proper background image of background_shape, given img
Return a proper background image of background_shape, given img
:param background_shape: a shape of [h, w]
:param img: an image
:returns: a background image
"""
return self._fill(background_shape, img)
@abstractmethod
def _fill(self, background_shape, img):
pass
class ConstantBackgroundFiller(BackgroundFiller):
""" Fill the background by a constant """
def __init__(self, value):
"""
:param value: the value to fill the background.
"""
self.value = value
def fill(self, background_shape, img):
def _fill(self, background_shape, img):
assert img.ndim in [3, 1]
if img.ndim == 3:
return_shape = background_shape + (3,)
......@@ -73,9 +94,13 @@ class ConstantBackgroundFiller(BackgroundFiller):
class CenterPaste(ImageAugmentor):
"""
Paste the image onto center of a background
Paste the image onto the center of a background canvas.
"""
def __init__(self, background_shape, background_filler=None):
"""
:param background_shape: shape of the background canvas.
:param background_filler: a `BackgroundFiller` instance. Default to zero-filler.
"""
if background_filler is None:
background_filler = ConstantBackgroundFiller(0)
......
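A hedged sketch of cropping and then pasting onto a larger constant-filled canvas, using the classes documented above; the import path and sizes are assumptions.

from tensorpack.dataflow import imgaug   # assumed import path

augs = imgaug.AugmentorList([
    imgaug.CenterCrop((200, 200)),
    imgaug.CenterPaste((256, 256), imgaug.ConstantBackgroundFiller(0)),  # zero-filled background
])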
......@@ -10,6 +10,7 @@ __all__ = ['GaussianDeform', 'GaussianMap']
# TODO really needs speedup
class GaussianMap(object):
""" Generate gaussian weighted deformation map"""
def __init__(self, image_shape, sigma=0.5):
assert len(image_shape) == 2
self.shape = image_shape
......@@ -53,14 +54,14 @@ def np_sample(img, coords):
# TODO input/output with different shape
class GaussianDeform(ImageAugmentor):
"""
Some kind of deformation
Some kind of deformation. Quite slow.
"""
#TODO docs
def __init__(self, anchors, shape, sigma=0.5, randrange=None):
"""
anchors: in [0,1] coordinate
shape: 2D image shape
randrange: default to shape[0] / 8
:param anchors: in [0,1] coordinate
:param shape: image shape in [h, w]
:param sigma: sigma for Gaussian weight
:param randrange: default to shape[0] / 8
"""
super(GaussianDeform, self).__init__()
self.anchors = anchors
......
......@@ -9,9 +9,12 @@ __all__ = ['BrightnessAdd', 'Contrast', 'MeanVarianceNormalize']
class BrightnessAdd(ImageAugmentor):
"""
Randomly add a value within [-delta,delta], and clip in [0,255]
Randomly adjust brightness.
"""
def __init__(self, delta, clip=True):
"""
Randomly add a value within [-delta, delta], and clip to [0, 255] if clip is True.
"""
assert delta > 0
self._init(locals())
......@@ -27,6 +30,10 @@ class Contrast(ImageAugmentor):
and clip to [0, 255]
"""
def __init__(self, factor_range, clip=True):
"""
:param factor_range: an interval from which to randomly sample the `contrast_factor`.
:param clip: boolean.
"""
self._init(locals())
def _augment(self, img):
......@@ -44,6 +51,9 @@ class MeanVarianceNormalize(ImageAugmentor):
where adjusted_stddev = max(stddev, 1.0/sqrt(num_pixels * channels))
"""
def __init__(self, all_channel=True):
"""
:param all_channel: if True, normalize all channels together; otherwise, normalize each channel separately.
"""
self.all_channel = all_channel
def _augment(self, img):
......
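A hedged list of the photometric augmentors documented above; the numeric values are illustrative and the import path is assumed.

from tensorpack.dataflow import imgaug   # assumed import path

augs = imgaug.AugmentorList([
    imgaug.BrightnessAdd(63),                      # add a value in [-63, 63], clip to [0, 255]
    imgaug.Contrast((0.2, 1.8)),                   # scale around the mean by a factor in [0.2, 1.8]
    imgaug.MeanVarianceNormalize(all_channel=True),
])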
......@@ -9,11 +9,16 @@ import cv2
__all__ = ['Flip', 'MapImage', 'Resize']
class Flip(ImageAugmentor):
def __init__(self, horiz=False, vert=False, prob=0.5):
"""
Random flip.
Args:
horiz, vert: True/False
"""
def __init__(self, horiz=False, vert=False, prob=0.5):
"""
Only one of horiz, vert can be set.
:param horiz: whether or not to apply horizontal flip.
:param vert: whether or not to apply vertical flip.
:param prob: probability of flip.
"""
if horiz and vert:
raise ValueError("Please use two Flip, with both 0.5 prob")
......@@ -34,7 +39,13 @@ class Flip(ImageAugmentor):
class MapImage(ImageAugmentor):
"""
Map the image array by a function.
"""
def __init__(self, func):
"""
:param func: a function which takes an image array and returns an augmented one
"""
self.func = func
def _augment(self, img):
......@@ -42,11 +53,10 @@ class MapImage(ImageAugmentor):
class Resize(ImageAugmentor):
"""Resize image to a target size"""
""" Resize image to a target size"""
def __init__(self, shape):
"""
Args:
shape: (h, w)
:param shape: shape in (h, w)
"""
self._init(locals())
......
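A hedged sketch of the geometric augmentors documented above; import path and values are assumptions.

from tensorpack.dataflow import imgaug   # assumed import path

augs = [
    imgaug.Flip(horiz=True, prob=0.5),       # horizontal flip half of the time
    imgaug.Resize((224, 224)),               # resize to (h, w)
    imgaug.MapImage(lambda arr: arr / 255.0),
]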
......@@ -15,8 +15,8 @@ class Sentinel:
class PrefetchProcess(multiprocessing.Process):
def __init__(self, ds, queue):
"""
ds: ds to take data from
queue: output queue to put results in
:param ds: a `DataFlow` instance to take data from
:param queue: output queue to put results in
"""
super(PrefetchProcess, self).__init__()
self.ds = ds
......@@ -30,11 +30,15 @@ class PrefetchProcess(multiprocessing.Process):
finally:
self.queue.put(Sentinel())
class PrefetchData(DataFlow):
"""
Prefetch data from a `DataFlow` using multiprocessing
"""
def __init__(self, ds, nr_prefetch, nr_proc=1):
"""
use multiprocess
:param ds: a `DataFlow` instance.
:param nr_prefetch: size of the queue to hold prefetched datapoints.
:param nr_proc: number of processes to use.
"""
self.ds = ds
self._size = self.ds.size()
......
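A hedged example of wrapping a dataflow in PrefetchData, per the docstring above; the import path, source dataflow, and queue/process counts are assumptions.

from tensorpack.dataflow import FakeData, PrefetchData   # assumed import path

ds = FakeData([(224, 224, 3)], size=1000)
ds = PrefetchData(ds, nr_prefetch=64, nr_proc=4)   # queue holds 64 datapoints, 4 worker processes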