Commit d8935ef3 authored by Yuxin Wu

update sphinx doc for dataflow/

parent edecca96
@@ -69,9 +69,11 @@ extensions = [
     #'sphinx.ext.coverage',
     #'sphinx.ext.mathjax',
     'sphinx.ext.mathbase',
+    'sphinx.ext.intersphinx',
     'sphinx.ext.viewcode',
 ]
 napoleon_google_docstring = True
+napoleon_include_init_with_doc = True
 napoleon_numpy_docstring = False
 napoleon_use_rtype = False
@@ -332,11 +334,9 @@ texinfo_documents = [
 # If true, do not generate a @detailmenu in the "Top" node's menu.
 #texinfo_no_detailmenu = False
-def skip(app, what, name, obj, skip, options):
-    # keep __init__
-    if name == "__init__":
-        return False
-    return skip
+intersphinx_timeout = 0.1
+intersphinx_mapping = {'python': ('https://docs.python.org/3.4', None)}
 def process_signature(app, what, name, obj, options, signature,
                       return_annotation):
@@ -350,7 +350,6 @@ def process_signature(app, what, name, obj, options, signature,
 def setup(app):
     from recommonmark.transform import AutoStructify
     app.connect('autodoc-process-signature', process_signature)
-    app.connect("autodoc-skip-member", skip)
     app.add_config_value(
         'recommonmark_config',
         {'url_resolver': lambda url: \
...
@@ -15,37 +15,41 @@ __all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow']
 class DataFlow(object):
     """ Base class for all DataFlow """
+    class Infinity:
+        pass
     @abstractmethod
     def get_data(self):
         """
-        A generator to generate data as a list.
-        Datapoint should be a mutable list.
-        Each component should be assumed immutable.
+        The method to generate datapoints.
+
+        Yields:
+            list: The datapoint, i.e. list of components.
         """
     def size(self):
         """
-        Size of this data flow.
+        Returns:
+            int: size of this data flow.
+
+        Raises:
+            :class:`NotImplementedError` if this DataFlow doesn't have a size.
         """
         raise NotImplementedError()
     def reset_state(self):
         """
-        Reset state of the dataflow. Will always be called before consuming data points.
-        for example, RNG **HAS** to be reset here if used in the DataFlow.
-        Otherwise it may not work well with prefetching, because different
-        processes will have the same RNG state.
+        Reset state of the dataflow. It has to be called before producing datapoints.
+
+        For example, RNG **has to** be reset if used in the DataFlow,
+        otherwise it won't work well with prefetching, because different
+        processes will have the same RNG state.
         """
         pass
 class RNGDataFlow(DataFlow):
-    """ A dataflow with rng"""
+    """ A DataFlow with RNG"""
     def reset_state(self):
+        """ Reset the RNG """
         self.rng = get_rng(self)
@@ -54,13 +58,14 @@ class ProxyDataFlow(DataFlow):
     def __init__(self, ds):
         """
-        :param ds: a :mod:`DataFlow` instance to proxy
+        Args:
+            ds (DataFlow): DataFlow to proxy.
         """
         self.ds = ds
     def reset_state(self):
         """
-        Will reset state of the proxied DataFlow
+        Reset state of the proxied DataFlow.
         """
         self.ds.reset_state()
...
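To make the documented `get_data`/`size`/`reset_state` contract concrete, here is a minimal sketch of a custom DataFlow; the class name and shapes are illustrative, and the import path assumes the package layout at the time of this commit.

```python
import numpy as np
from tensorpack.dataflow import RNGDataFlow  # assumed export path

class RandomPairs(RNGDataFlow):
    """ Yields [image, label] datapoints; self.rng is set by RNGDataFlow.reset_state(). """
    def __init__(self, num=100):
        self._num = num

    def size(self):
        return self._num

    def get_data(self):
        for _ in range(self._num):
            img = self.rng.uniform(0, 255, size=(32, 32, 3)).astype('float32')
            label = int(self.rng.randint(10))
            yield [img, label]   # a datapoint is a list of components

df = RandomPairs()
df.reset_state()          # must be called before producing datapoints
for dp in df.get_data():
    break
```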
@@ -25,20 +25,20 @@ IMG_W, IMG_H = 481, 321
 class BSDS500(RNGDataFlow):
     """
-    `Berkeley Segmentation Data Set and Benchmarks 500
+    `Berkeley Segmentation Data Set and Benchmarks 500 dataset
     <http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/resources.html#bsds500>`_.
-    Produce (image, label) pair, where image has shape (321, 481, 3) and
-    ranges in [0,255]. Label is binary and has shape (321, 481).
-    Those pixels annotated as boundaries by <=2 annotators are set to 0.
-    This is used in `Holistically-Nested Edge Detection
-    <http://arxiv.org/abs/1504.06375>`_.
+    Produce ``(image, label)`` pair, where ``image`` has shape (321, 481, 3(BGR)) and
+    ranges in [0,255].
+    ``Label`` is a floating point image of shape (321, 481) in range [0, 1].
+    The value of each pixel is ``number of times it is annotated as edge / total number of annotators for this image``.
     """
     def __init__(self, name, data_dir=None, shuffle=True):
         """
-        :param name: 'train', 'test', 'val'
-        :param data_dir: a directory containing the original 'BSR' directory.
+        Args:
+            name (str): 'train', 'test', 'val'
+            data_dir (str): a directory containing the original 'BSR' directory.
         """
         # check and download data
         if data_dir is None:
...
@@ -80,17 +80,7 @@ def get_filenames(dir, cifar_classnum):
 class CifarBase(RNGDataFlow):
-    """
-    Return [image, label],
-    image is 32x32x3 in the range [0,255]
-    """
     def __init__(self, train_or_test, shuffle=True, dir=None, cifar_classnum=10):
-        """
-        Args:
-            train_or_test: string either 'train' or 'test'
-            shuffle: default to True
-        """
         assert train_or_test in ['train', 'test']
         assert cifar_classnum == 10 or cifar_classnum == 100
         self.cifar_classnum = cifar_classnum
@@ -139,13 +129,22 @@ class CifarBase(RNGDataFlow):
 class Cifar10(CifarBase):
+    """
+    Produces [image, label] in Cifar10 dataset,
+    image is 32x32x3 in the range [0,255].
+    label is an int.
+    """
     def __init__(self, train_or_test, shuffle=True, dir=None):
+        """
+        Args:
+            train_or_test (str): either 'train' or 'test'.
+            shuffle (bool): shuffle the dataset.
+        """
         super(Cifar10, self).__init__(train_or_test, shuffle, dir, 10)
 class Cifar100(CifarBase):
+    """ Similar to Cifar10"""
     def __init__(self, train_or_test, shuffle=True, dir=None):
         super(Cifar100, self).__init__(train_or_test, shuffle, dir, 100)
...
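A hedged usage sketch for the loader documented above; with `dir=None` the data is fetched to a default location, following the conventions of the other dataset loaders in this commit.

```python
from tensorpack.dataflow.dataset import Cifar10  # assumed import path

ds = Cifar10('train', shuffle=True)
ds.reset_state()
img, label = next(ds.get_data())   # img: 32x32x3 in [0,255], label: int
```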
@@ -22,7 +22,7 @@ CAFFE_ILSVRC12_URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz"
 class ILSVRCMeta(object):
     """
-    Some metadata for ILSVRC dataset.
+    Provide methods to access metadata for ILSVRC dataset.
     """
     def __init__(self, dir=None):
@@ -37,7 +37,8 @@ class ILSVRCMeta(object):
     def get_synset_words_1000(self):
         """
-        :returns a dict of {cls_number: cls_name}
+        Returns:
+            dict: {cls_number: cls_name}
         """
         fname = os.path.join(self.dir, 'synset_words.txt')
         assert os.path.isfile(fname)
@@ -46,7 +47,8 @@ class ILSVRCMeta(object):
     def get_synset_1000(self):
         """
-        :returns a dict of {cls_number: synset_id}
+        Returns:
+            dict: {cls_number: synset_id}
         """
         fname = os.path.join(self.dir, 'synsets.txt')
         assert os.path.isfile(fname)
@@ -59,8 +61,10 @@ class ILSVRCMeta(object):
     def get_image_list(self, name):
         """
-        :param name: 'train' or 'val' or 'test'
-        :returns: list of (image filename, cls)
+        Args:
+            name (str): 'train' or 'val' or 'test'
+        Returns:
+            list: list of (image filename, label)
         """
         assert name in ['train', 'val', 'test']
         fname = os.path.join(self.dir, name + '.txt')
@@ -75,8 +79,10 @@ class ILSVRCMeta(object):
     def get_per_pixel_mean(self, size=None):
         """
-        :param size: return image size in [h, w]. default to (256, 256)
-        :returns: per-pixel mean as an array of shape (h, w, 3) in range [0, 255]
+        Args:
+            size (tuple): image size in (h, w). Defaults to (256, 256).
+        Returns:
+            np.ndarray: per-pixel mean of shape (h, w, 3 (BGR)) in range [0, 255].
         """
         obj = self.caffepb.BlobProto()
@@ -91,18 +97,26 @@ class ILSVRCMeta(object):
 class ILSVRC12(RNGDataFlow):
+    """
+    Produces ILSVRC12 images of shape [h, w, 3(BGR)], and a label between [0, 999],
+    and optionally a bounding box of [xmin, ymin, xmax, ymax].
+    """
     def __init__(self, dir, name, meta_dir=None, shuffle=True,
                  dir_structure='original', include_bb=False):
         """
-        :param dir: A directory containing a subdir named `name`, where the
-            original ILSVRC12_`name`.tar gets decompressed.
-        :param name: 'train' or 'val' or 'test'
-        :param dir_structure: The dir structure of 'val' and 'test'.
-            If is 'original' then keep the original decompressed directory with list
-            of image files (as below). If set to 'train', use the the same
-            directory structure as 'train/', with class name as subdirectories.
-        :param include_bb: Include the bounding box. Maybe useful in training.
+        Args:
+            dir (str): A directory containing a subdir named ``name``, where the
+                original ``ILSVRC12_img_{name}.tar`` gets decompressed.
+            name (str): 'train' or 'val' or 'test'.
+            shuffle (bool): shuffle the dataset.
+            dir_structure (str): The dir structure of 'val' and 'test' directory.
+                If is 'original', it expects the original decompressed
+                directory, which only has a list of image files (as below).
+                If set to 'train', it expects the same two-level
+                directory structure similar to 'train/'.
+            include_bb (bool): Include the bounding box. Maybe useful in training.
+
+        Examples:
         When `dir_structure=='original'`, `dir` should have the following structure:
@@ -120,22 +134,16 @@ class ILSVRC12(RNGDataFlow):
             test/
               ILSVRC2012_test_00000001.JPEG
               ...
-            bbox/
-              n02134418/
-                n02134418_198.xml
-                ...
-              ...
-        After decompress ILSVRC12_img_train.tar, you can use the following
-        command to build the above structure for `train/`:
+        With ILSVRC12_img_*.tar, you can use the following
+        command to build the above structure:
         .. code-block:: none
-            tar xvf ILSVRC12_img_train.tar -C train && cd train
+            mkdir val && tar xvf ILSVRC12_img_val.tar -C val
+            mkdir test && tar xvf ILSVRC12_img_test.tar -C test
+            mkdir train && tar xvf ILSVRC12_img_train.tar -C train && cd train
             find -type f -name '*.tar' | parallel -P 10 'echo {} && mkdir -p {/.} && tar xf {} -C {/.}'
-        Or:
-            for i in *.tar; do dir=${i%.tar}; echo $dir; mkdir -p $dir; tar xf $i -C $dir; done
         """
         assert name in ['train', 'test', 'val']
         self.full_dir = os.path.join(dir, name)
@@ -158,10 +166,6 @@ class ILSVRC12(RNGDataFlow):
         return len(self.imglist)
     def get_data(self):
-        """
-        Produce original images of shape [h, w, 3(BGR)], and label,
-        and optionally a bbox of [xmin, ymin, xmax, ymax]
-        """
         idxs = np.arange(len(self.imglist))
         add_label_to_fname = (self.name != 'train' and self.dir_structure != 'original')
         if self.shuffle:
...
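A sketch of reading ILSVRC12 together with the metadata helpers above; the dataset root is a placeholder and the import path is an assumption.

```python
from tensorpack.dataflow.dataset import ILSVRC12, ILSVRCMeta  # assumed import path

meta = ILSVRCMeta()                   # uses/downloads caffe_ilsvrc12 metadata per this file
words = meta.get_synset_words_1000()  # {cls_number: cls_name}

ds = ILSVRC12('/path/to/ILSVRC12', 'val', shuffle=False)
ds.reset_state()
for img, label in ds.get_data():      # img: [h, w, 3(BGR)], label in [0, 999]
    print(img.shape, words[label])
    break
```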
@@ -65,14 +65,15 @@ def extract_labels(filename):
 class Mnist(RNGDataFlow):
     """
-    Return [image, label],
-    image is 28x28 in the range [0,1]
+    Produces [image, label] in MNIST dataset,
+    image is 28x28 in the range [0,1], label is an int.
     """
     def __init__(self, train_or_test, shuffle=True, dir=None):
         """
         Args:
-            train_or_test: string either 'train' or 'test'
+            train_or_test (str): either 'train' or 'test'
+            shuffle (bool): shuffle the dataset
         """
         if dir is None:
             dir = get_dataset_path('mnist_data')
...
@@ -21,15 +21,17 @@ SVHN_URL = "http://ufldl.stanford.edu/housenumbers/"
 class SVHNDigit(RNGDataFlow):
     """
-    SVHN Cropped Digit Dataset.
-    return img of 32x32x3, label of 0-9
+    `SVHN <http://ufldl.stanford.edu/housenumbers/>`_ Cropped Digit Dataset.
+    Produces [img, label], img of 32x32x3 in range [0,255], label of 0-9
     """
     _Cache = {}
     def __init__(self, name, data_dir=None, shuffle=True):
         """
-        :param name: 'train', 'test', or 'extra'
-        :param data_dir: a directory containing the original {train,test,extra}_32x32.mat
+        Args:
+            name (str): 'train', 'test', or 'extra'.
+            data_dir (str): a directory containing the original {train,test,extra}_32x32.mat.
+            shuffle (bool): shuffle the dataset.
         """
         self.shuffle = shuffle
...
@@ -18,13 +18,11 @@ def read_json(fname):
     f.close()
     return ret
-# TODO shuffle
 class VisualQA(DataFlow):
     """
-    Visual QA dataset. See http://visualqa.org/
-    Simply read q/a json file and produce q/a pairs in their original format.
+    `Visual QA <http://visualqa.org/>`_ dataset.
+    It simply reads the q/a json file and produces q/a pairs in their original format.
     """
     def __init__(self, question_file, annotation_file):
...
@@ -23,17 +23,17 @@ except ImportError:
 else:
     __all__.extend(['dump_dataflow_to_lmdb'])
-# TODO pass a name_func to write label as filename?
 def dump_dataset_images(ds, dirname, max_count=None, index=0):
-    """ Dump images from a `DataFlow` to a directory.
-    :param ds: a `DataFlow` instance.
-    :param dirname: name of the directory.
-    :param max_count: max number of images to dump
-    :param index: the index of the image component in a data point.
+    """ Dump images from a DataFlow to a directory.
+
+    Args:
+        ds (DataFlow): the DataFlow to dump.
+        dirname (str): name of the directory.
+        max_count (int): limit max number of images to dump. Defaults to unlimited.
+        index (int): the index of the image component in the data point.
     """
+    # TODO pass a name_func to write label as filename?
     mkdir_p(dirname)
     if max_count is None:
         max_count = sys.maxint
@@ -48,9 +48,15 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0):
 def dump_dataflow_to_lmdb(ds, lmdb_path):
-    """ Dump a `Dataflow` ds to a lmdb database, where the key is the index
-    and the data is the serialized datapoint.
-    The output database can be read directly by `LMDBDataPoint`
+    """
+    Dump a Dataflow to a lmdb database, where the keys are indices and values
+    are serialized datapoints.
+    The output database can be read directly by
+    :class:`tensorpack.dataflow.LMDBDataPoint`.
+
+    Args:
+        ds (DataFlow): the DataFlow to dump.
+        lmdb_path (str): output path. Either a directory or an mdb file.
     """
     assert isinstance(ds, DataFlow), type(ds)
     isdir = os.path.isdir(lmdb_path)
@@ -80,15 +86,20 @@ def dump_dataflow_to_lmdb(ds, lmdb_path):
 def dataflow_to_process_queue(ds, size, nr_consumer):
     """
-    Convert a `DataFlow` to a multiprocessing.Queue.
-    The dataflow will only be reset in the spawned process.
-    :param ds: a `DataFlow`
-    :param size: size of the queue
-    :param nr_consumer: number of consumer of the queue.
-        will add this many of `DIE` sentinel to the end of the queue.
-    :returns: (queue, process). The process will take data from `ds` to fill
-        the queue once you start it. Each element is (task_id, dp).
+    Convert a DataFlow to a :class:`multiprocessing.Queue`.
+    The DataFlow will only be reset in the spawned process.
+
+    Args:
+        ds (DataFlow): the DataFlow to dump.
+        size (int): size of the queue
+        nr_consumer (int): number of consumers of the queue.
+            The producer will add this many ``DIE`` sentinels to the end of the queue.
+
+    Returns:
+        tuple(queue, process):
+            The process will take data from ``ds`` and fill
+            the queue, once you start it. Each element in the queue is (idx,
+            dp). idx can be the ``DIE`` sentinel when ``ds`` is exhausted.
     """
     q = mp.Queue(size)
...
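A round-trip sketch for `dump_dataflow_to_lmdb` and `LMDBDataPoint` based on the docstrings above; the lmdb path is a placeholder and the export paths are assumptions.

```python
from tensorpack.dataflow import FakeData, LMDBDataPoint  # assumed exports
from tensorpack.dataflow.dftools import dump_dataflow_to_lmdb

ds = FakeData([[32, 32, 3], [1]], size=100)    # fake [image, label]-like datapoints
dump_dataflow_to_lmdb(ds, '/tmp/fake.lmdb')    # keys are indices, values are serialized dps

df = LMDBDataPoint('/tmp/fake.lmdb', shuffle=False)
df.reset_state()
dp = next(df.get_data())                       # the deserialized datapoint
```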
@@ -26,7 +26,7 @@ try:
 except ImportError:
     logger.warn_dependency("LMDBData", 'lmdb')
 else:
-    __all__.extend(['LMDBData', 'CaffeLMDB', 'LMDBDataDecoder', 'LMDBDataPoint'])
+    __all__.extend(['LMDBData', 'LMDBDataDecoder', 'LMDBDataPoint', 'CaffeLMDB'])
 try:
     import sklearn.datasets
@@ -40,19 +40,23 @@ else:
 Adapters for different data format.
 """
-# TODO lazy load
 class HDF5Data(RNGDataFlow):
     """
-    Zip data from different paths in an HDF5 file. Will load all data into memory.
+    Zip data from different paths in an HDF5 file.
+
+    Warning:
+        The current implementation will load all data into memory.
     """
+    # TODO lazy load
     def __init__(self, filename, data_paths, shuffle=True):
         """
-        :param filename: h5 data file.
-        :param data_paths: list of h5 paths to zipped. For example ['images', 'labels']
-        :param shuffle: shuffle the order of all data.
+        Args:
+            filename (str): h5 data file.
+            data_paths (list): list of h5 paths to be zipped.
+                For example `['images', 'labels']`.
+            shuffle (bool): shuffle all data.
         """
         self.f = h5py.File(filename, 'r')
         logger.info("Loading {} to memory...".format(filename))
@@ -74,9 +78,13 @@ class HDF5Data(RNGDataFlow):
 class LMDBData(RNGDataFlow):
-    """ Read a lmdb and produce k,v pair """
+    """ Read a LMDB database and produce (k,v) pairs """
     def __init__(self, lmdb_path, shuffle=True):
+        """
+        Args:
+            lmdb_path (str): a directory or a file.
+            shuffle (bool): shuffle the keys or not.
+        """
         self._lmdb_path = lmdb_path
         self._shuffle = shuffle
         self.open_lmdb()
@@ -122,11 +130,14 @@ class LMDBData(RNGDataFlow):
 class LMDBDataDecoder(LMDBData):
+    """ Read a LMDB database and produce a decoded output."""
     def __init__(self, lmdb_path, decoder, shuffle=True):
         """
-        :param decoder: a function taking k, v and return a data point,
-            or return None to skip
+        Args:
+            lmdb_path (str): a directory or a file.
+            decoder (k,v -> dp | None): a function taking k, v and returning a datapoint,
+                or returning None to discard.
+            shuffle (bool): shuffle the keys or not.
         """
         super(LMDBDataDecoder, self).__init__(lmdb_path, shuffle)
         self.decoder = decoder
@@ -139,17 +150,31 @@ class LMDBDataDecoder(LMDBData):
 class LMDBDataPoint(LMDBDataDecoder):
-    """ Read a LMDB file where each value is a serialized datapoint"""
+    """ Read a LMDB file and produce deserialized values.
+        This can work with :func:`tensorpack.dataflow.dftools.dump_dataflow_to_lmdb`. """
     def __init__(self, lmdb_path, shuffle=True):
+        """
+        Args:
+            lmdb_path (str): a directory or a file.
+            shuffle (bool): shuffle the keys or not.
+        """
         super(LMDBDataPoint, self).__init__(
             lmdb_path, decoder=lambda k, v: loads(v), shuffle=shuffle)
 class CaffeLMDB(LMDBDataDecoder):
-    """ Read a Caffe LMDB file where each value contains a caffe.Datum protobuf """
+    """
+    Read a Caffe LMDB file where each value contains a ``caffe.Datum`` protobuf.
+    Produces datapoints of the format: [HWC image, label].
+    """
     def __init__(self, lmdb_path, shuffle=True):
+        """
+        Args:
+            lmdb_path (str): a directory or a file.
+            shuffle (bool): shuffle the keys or not.
+        """
         cpb = get_caffe_pb()
         def decoder(k, v):
@@ -168,9 +193,14 @@ class CaffeLMDB(LMDBDataDecoder):
 class SVMLightData(RNGDataFlow):
-    """ Read X,y from a svmlight file """
+    """ Read X,y from a svmlight file, and produce [X_i, y_i] pairs. """
     def __init__(self, filename, shuffle=True):
+        """
+        Args:
+            filename (str): input file
+            shuffle (bool): shuffle the data
+        """
         self.X, self.y = sklearn.datasets.load_svmlight_file(filename)
         self.X = np.asarray(self.X.todense())
         self.shuffle = shuffle
...
@@ -12,13 +12,13 @@ __all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageComponents']
 class ImageFromFile(RNGDataFlow):
+    """ Produce images read from a list of files. """
     def __init__(self, files, channel=3, resize=None, shuffle=False):
         """
-        Generate images of 1 channel or 3 channels (in RGB order) from list of files.
-        :param files: list of file paths
-        :param channel: 1 or 3 channel
-        :param resize: a (h, w) tuple. If given, will force a resize
+        Args:
+            files (list): list of file paths.
+            channel (int): 1 or 3. Produce RGB images if channel==3.
+            resize (tuple): (h, w). If given, resize the image.
         """
         assert len(files), "No image files given to ImageFromFile!"
         self.files = files
@@ -45,14 +45,15 @@ class ImageFromFile(RNGDataFlow):
 class AugmentImageComponent(MapDataComponent):
+    """
+    Apply image augmentors on 1 component.
+    """
     def __init__(self, ds, augmentors, index=0):
         """
-        Augment the image component of datapoints
-        :param ds: a `DataFlow` instance.
-        :param augmentors: a list of `ImageAugmentor` instance to be applied in order.
-        :param index: the index (or list of indices) of the image component
-            in the produced datapoints by `ds`. default to be 0
+        Args:
+            ds (DataFlow): input DataFlow.
+            augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` to be applied in order.
+            index (int): the index of the image component to be augmented.
         """
         if isinstance(augmentors, AugmentorList):
             self.augs = augmentors
@@ -67,12 +68,16 @@ class AugmentImageComponent(MapDataComponent):
 class AugmentImageComponents(MapData):
+    """
+    Apply image augmentors on several components, with shared augmentation parameters.
+    """
     def __init__(self, ds, augmentors, index=(0, 1)):
-        """ Augment a list of images of the same shape, with the same parameters
-        :param ds: a `DataFlow` instance.
-        :param augmentors: a list of `ImageAugmentor` instance to be applied in order.
-        :param index: tuple of indices of the image components
+        """
+        Args:
+            ds (DataFlow): input DataFlow.
+            augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` instances to be applied in order.
+            index: tuple of indices of components.
         """
         self.augs = AugmentorList(augmentors)
         self.ds = ds
...
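A sketch of augmenting only the image component of a dataflow with `AugmentImageComponent` as documented above; the chosen augmentors are illustrative and the export paths are assumptions.

```python
from tensorpack.dataflow import FakeData, AugmentImageComponent, imgaug  # assumed exports

ds = FakeData([[64, 64, 3], [1]], size=10)
ds = AugmentImageComponent(
    ds,
    [imgaug.Flip(horiz=True), imgaug.Brightness(30)],
    index=0)                     # component 0 is the image
ds.reset_state()
```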
@@ -24,6 +24,7 @@ class Augmentor(object):
             setattr(self, k, v)
     def reset_state(self):
+        """ reset rng and other state """
         self.rng = get_rng(self)
     def augment(self, d):
@@ -64,9 +65,13 @@ class ImageAugmentor(Augmentor):
     def augment(self, img):
         """
-        Perform augmentation on the image in-place.
-        :param img: an [h,w] or [h,w,c] image
-        :returns: the augmented image, always of type 'float32'
+        Perform augmentation on the image (possibly) in-place.
+
+        Args:
+            img (np.ndarray): an [h,w] or [h,w,c] image.
+
+        Returns:
+            np.ndarray: the augmented image, always of type float32.
         """
         img, params = self._augment_return_params(img)
         return img
@@ -82,7 +87,8 @@ class AugmentorList(ImageAugmentor):
     def __init__(self, augmentors):
         """
-        :param augmentors: list of `ImageAugmentor` instance to be applied
+        Args:
+            augmentors (list): list of :class:`ImageAugmentor` instances to be applied.
         """
         self.augs = augmentors
         super(AugmentorList, self).__init__()
...
@@ -10,7 +10,7 @@ from six.moves import range
 import numpy as np
 __all__ = ['RandomCrop', 'CenterCrop', 'FixedCrop',
-           'RandomCropRandomShape', 'perturb_BB', 'RandomCropAroundBox']
+           'perturb_BB', 'RandomCropAroundBox', 'RandomCropRandomShape']
 class RandomCrop(ImageAugmentor):
@@ -18,7 +18,8 @@ class RandomCrop(ImageAugmentor):
     def __init__(self, crop_shape):
         """
-        :param crop_shape: a shape like (h, w)
+        Args:
+            crop_shape: (h, w) tuple or an int
         """
         crop_shape = shape2d(crop_shape)
         super(RandomCrop, self).__init__()
@@ -47,7 +48,8 @@ class CenterCrop(ImageAugmentor):
     def __init__(self, crop_shape):
         """
-        :param crop_shape: a shape like (h, w)
+        Args:
+            crop_shape: (h, w) tuple or an int
         """
         crop_shape = shape2d(crop_shape)
         self._init(locals())
@@ -67,9 +69,8 @@ class FixedCrop(ImageAugmentor):
     def __init__(self, rect):
         """
-        Two arguments defined the range in both axes to crop, min inclued, max excluded.
-        :param rect: a `Rect` instance
+        Args:
+            rect (Rect): min included, max excluded.
         """
         self._init(locals())
@@ -86,12 +87,15 @@ def perturb_BB(image_shape, bb, max_pertub_pixel,
                max_try=100):
     """
     Perturb a bounding box.
-    :param image_shape: [h, w]
-    :param bb: a `Rect` instance
-    :param max_pertub_pixel: pertubation on each coordinate
-    :param max_aspect_ratio_diff: result can't have an aspect ratio too different from the original
-    :param max_try: if cannot find a valid bounding box, return the original
-    :returns: new bounding box
+
+    Args:
+        image_shape: [h, w]
+        bb (Rect): original bounding box
+        max_pertub_pixel: perturbation on each coordinate
+        max_aspect_ratio_diff: result can't have an aspect ratio too different from the original
+        max_try: if cannot find a valid bounding box, return the original
+
+    Returns:
+        new bounding box
     """
     orig_ratio = bb.h * 1.0 / bb.w
     if rng is None:
@@ -117,13 +121,15 @@ def perturb_BB(image_shape, bb, max_pertub_pixel,
 class RandomCropAroundBox(ImageAugmentor):
     """
-    Crop a box around a bounding box
+    Crop a box around a bounding box by some random perturbation
     """
     def __init__(self, perturb_ratio, max_aspect_ratio_diff=0.3):
         """
-        :param perturb_ratio: perturb distance will be in [0, perturb_ratio * sqrt(w * h)]
-        :param max_aspect_ratio_diff: keep aspect ratio within the range
+        Args:
+            perturb_ratio (float): perturb distance will be in
+                ``[0, perturb_ratio * sqrt(w * h)]``
+            max_aspect_ratio_diff (float): keep aspect ratio difference within the range
         """
         super(RandomCropAroundBox, self).__init__()
         self._init(locals())
@@ -144,14 +150,18 @@ class RandomCropAroundBox(ImageAugmentor):
 class RandomCropRandomShape(ImageAugmentor):
+    """ Random crop with a random shape"""
     def __init__(self, wmin, hmin,
                  wmax=None, hmax=None,
                  max_aspect_ratio=None):
         """
-        Randomly crop a box of shape (h, w), sampled from [min, max](inclusive).
+        Randomly crop a box of shape (h, w), sampled from [min, max] (both inclusive).
         If max is None, will use the input image shape.
-        max_aspect_ratio is the upper bound of max(w,h)/min(w,h)
+
+        Args:
+            wmin, hmin, wmax, hmax: range to sample shape.
+            max_aspect_ratio (float): the upper bound of ``max(w,h)/min(w,h)``.
         """
         if max_aspect_ratio is None:
             max_aspect_ratio = 9999999
...
@@ -6,13 +6,12 @@ from .base import ImageAugmentor
 from ...utils import logger
 import numpy as np
-__all__ = ['GaussianDeform', 'GaussianMap']
+__all__ = ['GaussianDeform']
-# TODO really needs speedup
 class GaussianMap(object):
     """ Generate gaussian weighted deformation map"""
+    # TODO really needs speedup
     def __init__(self, image_shape, sigma=0.5):
         assert len(image_shape) == 2
@@ -20,6 +19,10 @@ class GaussianMap(object):
         self.sigma = sigma
     def get_gaussian_weight(self, anchor):
+        """
+        Args:
+            anchor: coordinate of the center
+        """
         ret = np.zeros(self.shape, dtype='float32')
         y, x = np.mgrid[:self.shape[0], :self.shape[1]]
@@ -55,20 +58,20 @@ def np_sample(img, coords):
         img[ucoory, lcoorx, :] * diffy * ndiffx
     return ret[:, :, 0, :]
-# TODO input/output with different shape
 class GaussianDeform(ImageAugmentor):
     """
-    Some kind of deformation. Quite slow.
+    Some kind of slow deformation.
     """
+    # TODO input/output with different shape
     def __init__(self, anchors, shape, sigma=0.5, randrange=None):
         """
-        :param anchors: in [0,1] coordinate
-        :param shape: image shape in [h, w]
-        :param sigma: sigma for Gaussian weight
-        :param randrange: default to shape[0] / 8
+        Args:
+            anchors (list): list of center coordinates in range [0,1].
+            shape (list or tuple): image shape in [h, w].
+            sigma (float): sigma for Gaussian weight
+            randrange (int): offset range. Defaults to shape[0] / 8
        """
         logger.warn("GaussianDeform is slow. Consider using it with 4 or more prefetching processes.")
         super(GaussianDeform, self).__init__()
...
@@ -17,8 +17,11 @@ class Rotation(ImageAugmentor):
                  interp=cv2.INTER_CUBIC,
                  border=cv2.BORDER_REPLICATE):
         """
-        :param max_deg: max abs value of the rotation degree
-        :param center_range: the location of the rotation center
+        Args:
+            max_deg (float): max abs value of the rotation angle (in degrees).
+            center_range (tuple): (min, max) range of the random rotation center.
+            interp: cv2 interpolation method
+            border: cv2 border method
         """
         super(Rotation, self).__init__()
         self._init(locals())
@@ -36,11 +39,16 @@ class Rotation(ImageAugmentor):
 class RotationAndCropValid(ImageAugmentor):
-    """ Random rotate and crop the largest possible rect without the border
-        This will produce images of different shapes.
+    """ Random rotate and then crop the largest possible rectangle.
+        Note that this will produce images of different shapes.
     """
     def __init__(self, max_deg, interp=cv2.INTER_CUBIC):
+        """
+        Args:
+            max_deg (float): max abs value of the rotation angle (in degrees).
+            interp: cv2 interpolation method
+        """
         super(RotationAndCropValid, self).__init__()
         self._init(locals())
@@ -63,7 +71,10 @@ class RotationAndCropValid(ImageAugmentor):
     @staticmethod
     def largest_rotated_rect(w, h, angle):
-        """ http://stackoverflow.com/questions/16702966/rotate-image-and-crop-out-black-borders """
+        """
+        Get largest rectangle after rotation.
+        http://stackoverflow.com/questions/16702966/rotate-image-and-crop-out-black-borders
+        """
         angle = angle / 180.0 * math.pi
         if w <= 0 or h <= 0:
             return 0, 0
...
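For reference, the linked StackOverflow derivation that `largest_rotated_rect` points to can be written as a standalone function (a sketch; angle here is in radians, unlike the method above which converts from degrees first).

```python
import math

def largest_rotated_rect(w, h, angle):
    """Width/height of the largest axis-aligned rectangle inside a
    w x h rectangle rotated by `angle` radians (the known SO derivation)."""
    if w <= 0 or h <= 0:
        return 0, 0
    width_is_longer = w >= h
    side_long, side_short = (w, h) if width_is_longer else (h, w)
    sin_a, cos_a = abs(math.sin(angle)), abs(math.cos(angle))
    if side_short <= 2.0 * sin_a * cos_a * side_long:
        # half-constrained case: two crop corners touch the longer side
        x = 0.5 * side_short
        wr, hr = (x / sin_a, x / cos_a) if width_is_longer else (x / cos_a, x / sin_a)
    else:
        # fully-constrained case: the crop touches all four sides
        cos_2a = cos_a * cos_a - sin_a * sin_a
        wr = (w * cos_a - h * sin_a) / cos_2a
        hr = (h * cos_a - w * sin_a) / cos_2a
    return wr, hr
```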
@@ -12,9 +12,8 @@ __all__ = ['Brightness', 'Contrast', 'MeanVarianceNormalize', 'GaussianBlur',
 class Brightness(ImageAugmentor):
     """
-    Random adjust brightness.
+    Randomly adjust brightness.
     """
     def __init__(self, delta, clip=True):
         """
         Randomly add a value within [-delta,delta], and clip in [0,255] if clip is True.
@@ -36,14 +35,14 @@ class Brightness(ImageAugmentor):
 class Contrast(ImageAugmentor):
     """
-    Apply x = (x - mean) * contrast_factor + mean to each channel
-    and clip to [0, 255]
+    Apply ``x = (x - mean) * contrast_factor + mean`` to each channel.
     """
     def __init__(self, factor_range, clip=True):
         """
-        :param factor_range: an interval to random sample the `contrast_factor`.
-        :param clip: boolean.
+        Args:
+            factor_range (list or tuple): an interval to randomly sample the `contrast_factor`.
+            clip (bool): clip to [0, 255] if True.
         """
         super(Contrast, self).__init__()
         self._init(locals())
@@ -61,14 +60,15 @@ class Contrast(ImageAugmentor):
 class MeanVarianceNormalize(ImageAugmentor):
     """
-    Linearly scales image to have zero mean and unit norm.
-    x = (x - mean) / adjusted_stddev
-    where adjusted_stddev = max(stddev, 1.0/sqrt(num_pixels * channels))
+    Linearly scales the image to have zero mean and unit norm.
+    ``x = (x - mean) / adjusted_stddev``
+    where ``adjusted_stddev = max(stddev, 1.0/sqrt(num_pixels * channels))``
     """
     def __init__(self, all_channel=True):
         """
-        :param all_channel: if True, normalize all channels together. else separately.
+        Args:
+            all_channel (bool): if True, normalize all channels together; else separately.
         """
         self.all_channel = all_channel
@@ -85,9 +85,13 @@ class MeanVarianceNormalize(ImageAugmentor):
 class GaussianBlur(ImageAugmentor):
+    """ Gaussian blur the image with random window size"""
     def __init__(self, max_size=3):
-        """:params max_size: (maximum kernel size-1)/2"""
+        """
+        Args:
+            max_size (int): max possible Gaussian window size would be 2 * max_size + 1
+        """
         super(GaussianBlur, self).__init__()
         self._init(locals())
@@ -103,8 +107,12 @@ class GaussianBlur(ImageAugmentor):
 class Gamma(ImageAugmentor):
+    """ Randomly adjust gamma """
     def __init__(self, range=(-0.5, 0.5)):
+        """
+        Args:
+            range (list or tuple): gamma range
+        """
         super(Gamma, self).__init__()
         self._init(locals())
@@ -119,8 +127,13 @@ class Gamma(ImageAugmentor):
 class Clip(ImageAugmentor):
+    """ Clip the pixel values """
     def __init__(self, min=0, max=255):
+        """
+        Args:
+            min, max: the clip range
+        """
         self._init(locals())
     def _augment(self, img, _):
@@ -129,10 +142,15 @@ class Clip(ImageAugmentor):
 class Saturation(ImageAugmentor):
+    """ Randomly adjust saturation.
+        Follows the implementation in `fb.resnet.torch
+        <https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L218>`_.
+    """
     def __init__(self, alpha=0.4):
-        """ Saturation,
-        see 'fb.resnet.torch' https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L218
+        """
+        Args:
+            alpha (float): maximum saturation change.
         """
         super(Saturation, self).__init__()
         assert alpha < 1
@@ -147,14 +165,19 @@ class Saturation(ImageAugmentor):
 class Lighting(ImageAugmentor):
+    """ Lighting noise, as in the paper
+        `ImageNet Classification with Deep Convolutional Neural Networks
+        <https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf>`_.
+        The implementation follows `fb.resnet.torch
+        <https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L184>`_.
+    """
     def __init__(self, std, eigval, eigvec):
-        """ Lighting noise.
-        See `ImageNet Classification with Deep Convolutional Neural Networks - Alex`
-        The implementation follows 'fb.resnet.torch':
-        https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L184
-
-        :param eigvec: each column is one eigen vector
+        """
+        Args:
+            std (float): maximum standard deviation
+            eigval: a vector of (3,). The eigenvalues of 3 channels.
+            eigvec: a 3x3 matrix. Each column is one eigenvector.
         """
         eigval = np.asarray(eigval)
         eigvec = np.asarray(eigvec)
...
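A numeric sketch of the Contrast formula above, `x = (x - mean) * factor + mean`, with clipping as done when `clip=True`; the pixel values are arbitrary.

```python
import numpy as np

img = np.array([[0., 100.], [200., 255.]], dtype='float32')
mean = img.mean()                                # 138.75
factor = 1.5
out = np.clip((img - mean) * factor + mean, 0, 255)
# pixels move away from the mean: 0 -> 0 (clipped), 100 -> 80.625,
# 200 -> 230.625, 255 -> 255 (clipped)
```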
@@ -11,15 +11,22 @@ __all__ = ['RandomChooseAug', 'MapImage', 'Identity', 'RandomApplyAug',
 class Identity(ImageAugmentor):
+    """ A no-op augmentor """
     def _augment(self, img, _):
         return img
 class RandomApplyAug(ImageAugmentor):
-    """ Randomly apply the augmentor with a prob. Otherwise do nothing"""
+    """ Randomly apply the augmentor with a probability.
+        Otherwise do nothing
+    """
     def __init__(self, aug, prob):
+        """
+        Args:
+            aug (ImageAugmentor): an augmentor
+            prob (float): the probability
+        """
         self._init(locals())
         super(RandomApplyAug, self).__init__()
@@ -43,10 +50,11 @@ class RandomApplyAug(ImageAugmentor):
 class RandomChooseAug(ImageAugmentor):
+    """ Randomly choose one from a list of augmentors """
     def __init__(self, aug_lists):
         """
-        :param aug_lists: list of augmentor, or list of (augmentor, probability) tuple
+        Args:
+            aug_lists (list): list of augmentors, or list of (augmentor, probability) tuples
         """
         if isinstance(aug_lists[0], (tuple, list)):
             prob = [k[1] for k in aug_lists]
@@ -73,11 +81,15 @@ class RandomChooseAug(ImageAugmentor):
 class RandomOrderAug(ImageAugmentor):
+    """
+    Apply the augmentors with randomized order.
+    """
     def __init__(self, aug_lists):
         """
-        Shuffle the augmentors into random order.
-        :param aug_lists: list of augmentor, or list of (augmentor, probability) tuple
+        Args:
+            aug_lists (list): list of augmentors.
+                The augmentors are assumed to not change the shape of images.
         """
         self._init(locals())
         super(RandomOrderAug, self).__init__()
@@ -109,7 +121,8 @@ class MapImage(ImageAugmentor):
     def __init__(self, func):
         """
-        :param func: a function which takes a image array and return a augmented one
+        Args:
+            func: a function which takes an image array and returns an augmented one
         """
         self.func = func
...
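A sketch of composing the meta-augmentors documented above; the inner augmentors and probabilities are illustrative, and the `imgaug` export path is assumed.

```python
from tensorpack.dataflow import imgaug  # assumed export path

aug = imgaug.RandomChooseAug([
    (imgaug.GaussianBlur(max_size=3), 0.5),   # chosen with probability 0.5
    (imgaug.Gamma(), 0.5),
])
aug = imgaug.RandomApplyAug(aug, prob=0.8)    # do nothing 20% of the time
aug.reset_state()
```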
@@ -11,8 +11,13 @@ __all__ = ['JpegNoise', 'GaussianNoise', 'SaltPepperNoise']
 class JpegNoise(ImageAugmentor):
+    """ Random Jpeg noise. """
     def __init__(self, quality_range=(40, 100)):
+        """
+        Args:
+            quality_range (tuple): range to sample Jpeg quality
+        """
         super(JpegNoise, self).__init__()
         self._init(locals())
@@ -25,10 +30,14 @@ class JpegNoise(ImageAugmentor):
 class GaussianNoise(ImageAugmentor):
+    """
+    Add random gaussian noise N(0, sigma^2) of the same shape to img.
+    """
     def __init__(self, sigma=1, clip=True):
         """
-        Add a gaussian noise N(0, sigma^2) of the same shape to img.
+        Args:
+            sigma (float): stddev of the Gaussian distribution.
+            clip (bool): clip the result to [0,255] in the end.
         """
         super(GaussianNoise, self).__init__()
         self._init(locals())
@@ -44,11 +53,15 @@ class GaussianNoise(ImageAugmentor):
 class SaltPepperNoise(ImageAugmentor):
-    def __init__(self, white_prob=0.05, black_prob=0.05):
-        """ Salt and pepper noise.
-        Randomly set some elements in img to 0 or 255, regardless of its channels.
-        """
+    """ Salt and pepper noise.
+        Randomly set some elements in img to 0 or 255, regardless of its channels.
+    """
+    def __init__(self, white_prob=0.05, black_prob=0.05):
+        """
+        Args:
+            white_prob (float), black_prob (float): probabilities of setting an element to 255 or 0.
+        """
         assert white_prob + black_prob <= 1, "Sum of probabilities cannot be greater than 1"
         super(SaltPepperNoise, self).__init__()
         self._init(locals())
...
@@ -13,20 +13,18 @@ __all__ = ['Flip', 'Resize', 'RandomResize', 'ResizeShortestEdge']
 class Flip(ImageAugmentor):
     """
-    Random flip.
+    Random flip the image either horizontally or vertically.
     """
     def __init__(self, horiz=False, vert=False, prob=0.5):
         """
-        Only one of horiz, vert can be set.
-
-        :param horiz: whether or not apply horizontal flip.
-        :param vert: whether or not apply vertical flip.
-        :param prob: probability of flip.
+        Args:
+            horiz (bool): use horizontal flip.
+            vert (bool): use vertical flip.
+            prob (float): probability of flip.
         """
         super(Flip, self).__init__()
         if horiz and vert:
-            raise ValueError("Please use two Flip instead.")
+            raise ValueError("Cannot do both horiz and vert. Please use two Flip instead.")
         elif horiz:
             self.code = 1
         elif vert:
@@ -53,7 +51,9 @@ class Resize(ImageAugmentor):
     def __init__(self, shape, interp=cv2.INTER_CUBIC):
         """
-        :param shape: shape in (h, w)
+        Args:
+            shape: (h, w) tuple or an int
+            interp: cv2 interpolation method
         """
         shape = tuple(shape2d(shape))
         self._init(locals())
@@ -68,11 +68,16 @@ class Resize(ImageAugmentor):
 class ResizeShortestEdge(ImageAugmentor):
-    """ Resize the shortest edge to a certain number while
-        keeping the aspect ratio
+    """
+    Resize the shortest edge to a certain number while
+    keeping the aspect ratio.
     """
     def __init__(self, size):
+        """
+        Args:
+            size (int): the size to resize the shortest edge to.
+        """
         size = size * 1.0
         self._init(locals())
@@ -87,15 +92,18 @@ class ResizeShortestEdge(ImageAugmentor):
 class RandomResize(ImageAugmentor):
-    """ randomly rescale w and h of the image"""
+    """ Randomly rescale w and h of the image"""
     def __init__(self, xrange, yrange, minimum=(0, 0), aspect_ratio_thres=0.15,
                  interp=cv2.INTER_CUBIC):
         """
-        :param xrange: (min, max) scaling ratio
-        :param yrange: (min, max) scaling ratio
-        :param minimum: (xmin, ymin). Avoid scaling down too much.
-        :param aspect_ratio_thres: at most change k=20% aspect ratio
+        Args:
+            xrange (tuple): (min, max) range of scaling ratio for w
+            yrange (tuple): (min, max) range of scaling ratio for h
+            minimum (tuple): (xmin, ymin). avoid scaling down too much.
+            aspect_ratio_thres (float): discard samples which change aspect ratio
+                larger than this threshold.
+            interp: cv2 interpolation method
         """
         super(RandomResize, self).__init__()
         self._init(locals())
...
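A numeric sketch of what `ResizeShortestEdge` computes: scale so the shorter side becomes `size`, keeping the aspect ratio (note `cv2.resize` takes `(w, h)`). The image shape is arbitrary.

```python
import cv2
import numpy as np

size = 256.0
img = np.zeros((480, 640, 3), np.uint8)
h, w = img.shape[:2]
scale = size / min(h, w)
out = cv2.resize(img, (int(round(w * scale)), int(round(h * scale))))
print(out.shape)   # (256, 341, 3)
```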
@@ -19,9 +19,11 @@ class BackgroundFiller(object):
         """
         Return a proper background image of background_shape, given img
-        :param background_shape: a shape of [h, w]
-        :param img: an image
-        :returns: a background image
+        Args:
+            background_shape: a shape of [h, w]
+            img: an image
+
+        Returns:
+            a background image
         """
         return self._fill(background_shape, img)
@@ -35,7 +37,8 @@ class ConstantBackgroundFiller(BackgroundFiller):
     def __init__(self, value):
         """
-        :param value: the value to fill the background.
+        Args:
+            value (float): the value to fill the background.
         """
         self.value = value
@@ -55,8 +58,9 @@ class CenterPaste(ImageAugmentor):
     def __init__(self, background_shape, background_filler=None):
         """
-        :param background_shape: shape of the background canvas.
-        :param background_filler: a `BackgroundFiller` instance. Default to zero-filler.
+        Args:
+            background_shape (tuple): shape of the background canvas.
+            background_filler (BackgroundFiller): How to fill the background. Defaults to zero-filler.
         """
         if background_filler is None:
             background_filler = ConstantBackgroundFiller(0)
...
@@ -48,15 +48,19 @@ class PrefetchProcess(mp.Process):
 class PrefetchData(ProxyDataFlow):
     """
-    Prefetch data from a `DataFlow` using multiprocessing
+    Prefetch data from a DataFlow using Python multiprocessing utilities.
+
+    Note:
+        This is significantly slower than :class:`PrefetchDataZMQ` when data
+        is large.
     """
     def __init__(self, ds, nr_prefetch, nr_proc=1):
         """
-        :param ds: a `DataFlow` instance.
-        :param nr_prefetch: size of the queue to hold prefetched datapoints.
-        :param nr_proc: number of processes to use. When larger than 1, order
-            of data points will be random.
+        Args:
+            ds (DataFlow): input DataFlow.
+            nr_prefetch (int): size of the queue to hold prefetched datapoints.
+            nr_proc (int): number of processes to use.
         """
         super(PrefetchData, self).__init__(ds)
         try:
@@ -85,9 +89,8 @@ class PrefetchData(ProxyDataFlow):
 def BlockParallel(ds, queue_size):
-    # TODO more doc
     """
-    Insert `BlockParallel` in dataflow pipeline to block parallelism on ds
+    Insert ``BlockParallel`` in dataflow pipeline to block parallelism on ds.
     :param ds: a `DataFlow`
     :param queue_size: size of the queue used
@@ -96,7 +99,6 @@ def BlockParallel(ds, queue_size):
 class PrefetchProcessZMQ(mp.Process):
     def __init__(self, ds, conn_name):
         """
         :param ds: a `DataFlow` instance.
@@ -118,15 +120,17 @@ class PrefetchProcessZMQ(mp.Process):
 class PrefetchDataZMQ(ProxyDataFlow):
-    """ Work the same as `PrefetchData`, but faster. """
+    """
+    Prefetch data from a DataFlow using multiple processes, with ZMQ for
+    communication.
+    """
     def __init__(self, ds, nr_proc=1, pipedir=None):
         """
-        :param ds: a `DataFlow` instance.
-        :param nr_proc: number of processes to use. When larger than 1, order
-            of datapoints will be random.
-        :param pipedir: a local directory where the pipes would be.
-            Useful if you're running on non-local FS such as NFS.
+        Args:
+            ds (DataFlow): input DataFlow.
+            nr_proc (int): number of processes to use.
+            pipedir (str): a local directory where the pipes should be put.
+                Useful if you're running on non-local FS such as NFS or GlusterFS.
         """
         super(PrefetchDataZMQ, self).__init__(ds)
         try:
@@ -185,10 +189,20 @@ class PrefetchDataZMQ(ProxyDataFlow):
 class PrefetchOnGPUs(PrefetchDataZMQ):
-    """ Prefetch with each process having a specific CUDA_VISIBLE_DEVICES
-        variable"""
+    """
+    Prefetch with each process having its own ``CUDA_VISIBLE_DEVICES`` variable
+    mapped to one GPU.
+    """
     def __init__(self, ds, gpus, pipedir=None):
+        """
+        Args:
+            ds (DataFlow): input DataFlow.
+            gpus (list[int]): list of GPUs to use. Will also start this many processes.
+            pipedir (str): a local directory where the pipes should be put.
+                Useful if you're running on non-local FS such as NFS or GlusterFS.
+        """
         self.gpus = gpus
         super(PrefetchOnGPUs, self).__init__(ds, len(gpus), pipedir)
...
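A typical prefetch usage per the docstrings above; `nr_proc` is illustrative and the export paths are assumptions.

```python
from tensorpack.dataflow import FakeData, PrefetchDataZMQ  # assumed exports

ds = FakeData([[224, 224, 3], [1]], size=1000)
ds = PrefetchDataZMQ(ds, nr_proc=4)   # 4 worker processes talking over ZMQ pipes
ds.reset_state()                      # starts the workers
```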
@@ -7,26 +7,21 @@ import numpy as np
 import copy
 from six.moves import range
 from .base import DataFlow, RNGDataFlow
-from ..utils.serialize import loads
 __all__ = ['FakeData', 'DataFromQueue', 'DataFromList']
-try:
-    import zmq
-except:
-    pass
-else:
-    __all__.append('DataFromSocket')
 class FakeData(RNGDataFlow):
-    """ Generate fake fixed data of given shapes"""
-    def __init__(self, shapes, size, random=True, dtype='float32'):
+    """ Generate fake data of given shapes"""
+    def __init__(self, shapes, size=1000, random=True, dtype='float32'):
         """
-        :param shapes: a list of lists/tuples
-        :param size: size of this DataFlow
-        :param random: whether to randomly generate data every iteration. note
-            that only generating the data could be time-consuming!
+        Args:
+            shapes (list): a list of lists/tuples. Shapes of each component.
+            size (int): size of this DataFlow.
+            random (bool): whether to randomly generate data every iteration.
+                Note that merely generating the data could sometimes be time-consuming!
+            dtype (str): data type.
         """
         super(FakeData, self).__init__()
         self.shapes = shapes
@@ -49,8 +44,11 @@ class FakeData(RNGDataFlow):
 class DataFromQueue(DataFlow):
     """ Produce data from a queue """
     def __init__(self, queue):
+        """
+        Args:
+            queue (queue): a queue with ``get()`` method.
+        """
         self.queue = queue
     def get_data(self):
@@ -62,6 +60,11 @@ class DataFromQueue(DataFlow):
     """ Produce data from a list"""
     def __init__(self, lst, shuffle=True):
+        """
+        Args:
+            lst (list): input list.
+            shuffle (bool): shuffle data.
+        """
         super(DataFromList, self).__init__()
         self.lst = lst
         self.shuffle = shuffle
@@ -78,22 +81,3 @@ class DataFromList(RNGDataFlow):
         self.rng.shuffle(idxs)
         for k in idxs:
             yield self.lst[k]
-class DataFromSocket(DataFlow):
-    """ Produce data from a zmq socket"""
-    def __init__(self, socket_name):
-        self._name = socket_name
-    def get_data(self):
-        try:
-            ctx = zmq.Context()
-            socket = ctx.socket(zmq.PULL)
-            socket.bind(self._name)
-            while True:
-                dp = loads(socket.recv(copy=False))
-                yield dp
-        finally:
-            ctx.destroy(linger=0)
...
@@ -19,6 +19,14 @@ from ..utils.serialize import dumps, loads
 def serve_data(ds, addr):
+    """
+    Serve the DataFlow on a ZMQ socket addr.
+    It will dump and send each datapoint to this addr with a PUSH socket.
+
+    Args:
+        ds (DataFlow): DataFlow to serve. Will infinitely loop over the DataFlow.
+        addr: a ZMQ socket addr.
+    """
     ctx = zmq.Context()
     socket = ctx.socket(zmq.PUSH)
     socket.set_hwm(10)
@@ -27,7 +35,7 @@ def serve_data(ds, addr):
     try:
         ds.reset_state()
         logger.info("Serving data at {}".format(addr))
-        # TODO print statistics here
+        # TODO print statistics such as speed
         while True:
             for dp in ds.get_data():
                 socket.send(dumps(dp), copy=False)
@@ -39,17 +47,25 @@ def serve_data(ds, addr):
 class RemoteData(DataFlow):
+    """ Produce data from a ZMQ socket. """
     def __init__(self, addr):
-        self.ctx = zmq.Context()
-        self.socket = self.ctx.socket(zmq.PULL)
-        self.socket.set_hwm(10)
-        self.socket.connect(addr)
+        """
+        Args:
+            addr (str): addr of the socket to connect to.
+        """
+        self._addr = addr
     def get_data(self):
+        try:
+            ctx = zmq.Context()
+            socket = ctx.socket(zmq.PULL)
+            socket.connect(self._addr)
-        while True:
-            dp = loads(self.socket.recv(copy=False))
-            yield dp
+            while True:
+                dp = loads(socket.recv(copy=False))
+                yield dp
+        finally:
+            ctx.destroy(linger=0)
 if __name__ == '__main__':
...
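A sketch of the producer/consumer pair documented above; the ipc address is a placeholder, each side runs in its own process, and the module path is an assumption.

```python
from tensorpack.dataflow import FakeData                        # assumed export
from tensorpack.dataflow.remote import serve_data, RemoteData   # assumed module path

# in the producer process:
#   serve_data(FakeData([[32, 32, 3]], size=100), 'ipc:///tmp/tp-demo')

# in the consumer process:
df = RemoteData('ipc:///tmp/tp-demo')
for dp in df.get_data():
    break
```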
@@ -12,7 +12,9 @@ except ImportError:
     logger.warn_dependency('TFFuncMapper', 'tensorflow')
     __all__ = []
 else:
-    __all__ = ['TFFuncMapper']
+    __all__ = []
+
+""" This file was deprecated """
 class TFFuncMapper(ProxyDataFlow):
...
@@ -138,7 +138,7 @@ class ModelFromMetaGraph(ModelDesc):
     def __init__(self, filename):
         """
         Args:
-            filename(str): file name of the saved meta graph.
+            filename (str): file name of the saved meta graph.
         """
         tf.train.import_meta_graph(filename)
         all_coll = tf.get_default_graph().get_all_collection_keys()
...
@@ -59,9 +59,9 @@ def Dropout(x, keep_prob=0.5, is_training=None):
     Neural Networks from Overfitting <http://dl.acm.org/citation.cfm?id=2670313>`_.
     Args:
-        keep_prob: the probability that each element is kept. It is only used
+        keep_prob (float): the probability that each element is kept. It is only used
             when is_training=True.
-        is_training: If None, will use the current :class:`tensorpack.tfutils.TowerContext`
+        is_training (bool): If None, will use the current :class:`tensorpack.tfutils.TowerContext`
             to figure out.
     """
     if is_training is None:
...