Commit 7b91ba1d authored by Yuxin Wu

write some more comments in dataflow

parent bee6798b
@@ -13,9 +13,10 @@ __all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow', 'DataFlowTerminated']
 class DataFlowTerminated(BaseException):
     """
-    An exception indicating that the DataFlow is unable to produce any more data:
-    calling :meth:`get_data` will not give a valid iterator any more.
-    In most DataFlow this will not be raised.
+    An exception indicating that the DataFlow is unable to produce any more
+    data, i.e. something wrong happened so that calling :meth:`get_data`
+    cannot give a valid iterator any more.
+    In most DataFlow this will never be raised.
     """
     pass
@@ -80,7 +81,9 @@ class RNGDataFlow(DataFlow):
 class ProxyDataFlow(DataFlow):
-    """ Base class for DataFlow that proxies another"""
+    """ Base class for DataFlow that proxies another.
+        Every method is proxied to ``self.ds`` unless overridden by a subclass.
+    """
     def __init__(self, ds):
         """
@@ -90,9 +93,6 @@ class ProxyDataFlow(DataFlow):
         self.ds = ds

     def reset_state(self):
-        """
-        Reset state of the proxied DataFlow.
-        """
         self.ds.reset_state()

     def size(self):
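For illustration, a minimal sketch of a ``ProxyDataFlow`` subclass, assuming the class is importable from ``tensorpack.dataflow``; the ``TakeFirstN`` name is hypothetical. Only ``get_data`` is overridden, so ``size()`` and ``reset_state()`` fall through to ``self.ds``:

    from tensorpack.dataflow import ProxyDataFlow

    class TakeFirstN(ProxyDataFlow):
        """Produce only the first n datapoints of ds; all other methods are proxied."""
        def __init__(self, ds, n):
            super(TakeFirstN, self).__init__(ds)
            self.n = n

        def get_data(self):
            for i, dp in enumerate(self.ds.get_data()):
                if i >= self.n:
                    break
                yield dp
            # caveat: the proxied size() still reports ds.size(), which is now too large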
......
@@ -32,7 +32,7 @@ class TestDataSpeed(ProxyDataFlow):
         self.test_size = size

     def get_data(self):
-        """ Will start testing at the beginning, then produce data normally. """
+        """ Will run testing at the beginning, then produce data normally. """
         self.start_test()
         for dp in self.ds.get_data():
             yield dp
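A hedged usage sketch based only on the methods visible above: ``start_test`` is the benchmark entry point that ``get_data`` calls first. ``DataFromList`` is assumed to live in the same package and to yield each list element as one datapoint:

    from tensorpack.dataflow import DataFromList, TestDataSpeed

    df = DataFromList([[i] for i in range(1000)], shuffle=False)
    # benchmark iteration speed over the first 1000 datapoints
    TestDataSpeed(df, size=1000).start_test()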
@@ -62,22 +62,23 @@ class BatchData(ProxyDataFlow):
     Concat datapoints into batches.
     It produces datapoints of the same number of components as ``ds``, but
     each component has one new extra dimension of size ``batch_size``.
-    The new component can be a list of the original datapoints, or an ndarray
-    of the original datapoints.
+    A batch can be either a list of original components, or (by default)
+    a numpy array of original components.
     """

     def __init__(self, ds, batch_size, remainder=False, use_list=False):
         """
         Args:
-            ds (DataFlow): Its components must be either scalars or :class:`np.ndarray`.
-                Each component has to be of the same shape across datapoints.
+            ds (DataFlow): When ``use_list=False``, the components of ``ds``
+                must be either scalars or :class:`np.ndarray`, and
+                components have to have consistent shapes across ``ds``.
             batch_size(int): batch size
-            remainder (bool): whether to return the remaining data smaller than a batch_size.
-                If set True, it will possibly generates a data point of a smaller batch size.
-                Otherwise, all generated data are guranteed to have the same size.
-            use_list (bool): if True, it will run faster by producing a list
-                of datapoints instead of an ndarray of datapoints, avoiding an
-                extra copy.
+            remainder (bool): When the remaining datapoints in ``ds`` are not
+                enough to form a batch, whether or not to also produce the remaining
+                data as a smaller batch.
+                If set to False, all generated datapoints are guaranteed to have the same batch size.
+            use_list (bool): if True, each component will contain a list
+                of datapoints instead of a numpy array of datapoints. This also avoids an extra copy.
         """
         super(BatchData, self).__init__(ds)
         if not remainder:
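A minimal sketch of the batching behavior described above, assuming ``BatchData`` and ``DataFromList`` are importable from ``tensorpack.dataflow``:

    import numpy as np
    from tensorpack.dataflow import DataFromList, BatchData

    # ten datapoints, each with a single component of shape (3,)
    df = DataFromList([[np.arange(3) + i] for i in range(10)], shuffle=False)
    df = BatchData(df, batch_size=4, remainder=True)
    df.reset_state()
    for dp in df.get_data():
        print(dp[0].shape)   # (4, 3), (4, 3), then (2, 3) from the remainder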
@@ -152,9 +153,10 @@ class BatchDataByShape(BatchData):
     datapoints of different shape, and batches will be formed from those who
     have the same shape.
-    It is implemented by a dict{shape -> datapoints}.
-    Datapoints of uncommon shapes may never be enough to form a batch and
-    never get generated.
+    Note:
+        It is implemented by a dict{shape -> datapoints}.
+        Datapoints of uncommon shapes may never be enough to form a batch and
+        never get generated.
     """

     def __init__(self, ds, batch_size, idx):
         """
@@ -184,8 +186,8 @@ class BatchDataByShape(BatchData):
 class FixedSizeData(ProxyDataFlow):
-    """ Generate data from another DataFlow, but with a fixed size.
-        The iterator of the underlying DataFlow will be kept if not exhausted.
+    """ Generate data from another DataFlow, but with a fixed total count.
+        The iterator state of the underlying DataFlow will be kept if not exhausted.
     """
     def __init__(self, ds, size):
         """
@@ -220,19 +222,21 @@ class FixedSizeData(ProxyDataFlow):
 class MapData(ProxyDataFlow):
-    """ Apply a mapper/filter on the DataFlow"""
+    """
+    Apply a mapper/filter on the DataFlow.
+
+    Note:
+        1. Please make sure func doesn't modify the components
+           unless you're certain it's safe.
+        2. If you discard some datapoints, ``ds.size()`` will be incorrect.
+    """

     def __init__(self, ds, func):
         """
         Args:
             ds (DataFlow): input DataFlow
             func (datapoint -> datapoint | None): takes a datapoint and returns a new
-                datapoint. Return None to discard this data point.
-                Note that if you use the filter feature, ``ds.size()`` will be incorrect.
-
-        Note:
-            Please make sure func doesn't modify the components
-            unless you're certain it's safe.
+                datapoint. Return None to discard this datapoint.
         """
         super(MapData, self).__init__(ds)
         self.func = func
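A sketch of the mapper/filter usage, with the same import assumptions; returning None drops a datapoint, which is why ``size()`` becomes inaccurate as the Note warns:

    from tensorpack.dataflow import DataFromList, MapData

    df = DataFromList([[1], [2], [3], [4]], shuffle=False)
    # square each value, dropping odd ones by returning None
    df = MapData(df, lambda dp: [dp[0] * dp[0]] if dp[0] % 2 == 0 else None)
    df.reset_state()
    print(list(df.get_data()))   # [[4], [16]], although df.size() still says 4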
@@ -245,20 +249,23 @@ class MapData(ProxyDataFlow):
 class MapDataComponent(MapData):
-    """ Apply a mapper/filter on a datapoint component"""
+    """
+    Apply a mapper/filter on a datapoint component.
+
+    Note:
+        1. This dataflow itself doesn't modify the datapoints.
+           But please make sure func doesn't modify the components
+           unless you're certain it's safe.
+        2. If you discard some datapoints, ``ds.size()`` will be incorrect.
+    """

     def __init__(self, ds, func, index=0):
         """
         Args:
             ds (DataFlow): input DataFlow.
             func (TYPE -> TYPE|None): takes ``dp[index]``, returns a new value for ``dp[index]``.
                 return None to discard this datapoint.
-                Note that if you use the filter feature, ``ds.size()`` will be incorrect.
             index (int): index of the component.
-
-        Note:
-            This proxy itself doesn't modify the datapoints.
-            But please make sure func doesn't modify the components
-            unless you're certain it's safe.
         """
         def f(dp):
             r = func(dp[index])
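A sketch of mapping a single component, with the same import assumptions; only ``dp[index]`` is transformed and the other components pass through:

    from tensorpack.dataflow import DataFromList, MapDataComponent

    df = DataFromList([[1, 'a'], [2, 'b']], shuffle=False)
    df = MapDataComponent(df, lambda s: s.upper(), index=1)
    df.reset_state()
    print(list(df.get_data()))   # [[1, 'A'], [2, 'B']]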
@@ -272,7 +279,8 @@ class MapDataComponent(MapData):
 class RepeatedData(ProxyDataFlow):
     """ Take data points from another DataFlow and produce them until
-        it's exhausted for certain amount of times.
+        it's exhausted a certain number of times, i.e.:
+        dp1, dp2, ..., dpn, dp1, dp2, ..., dpn
     """

     def __init__(self, ds, nr):
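A sketch of this ordering, with the same import assumptions (``RepeatedDataPoint`` below is analogous, but repeats each datapoint consecutively instead of repeating whole epochs):

    from tensorpack.dataflow import DataFromList, RepeatedData

    df = DataFromList([[1], [2]], shuffle=False)
    df = RepeatedData(df, 3)
    df.reset_state()
    print([dp[0] for dp in df.get_data()])   # [1, 2, 1, 2, 1, 2]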
@@ -306,9 +314,9 @@ class RepeatedData(ProxyDataFlow):
 class RepeatedDataPoint(ProxyDataFlow):
     """ Take data points from another DataFlow and produce them a
-        certain number of times dp1, ..., dp1, dp2, ..., dp2, ...
+        certain number of times, i.e.:
+        dp1, dp1, ..., dp1, dp2, ..., dp2, ...
     """

     def __init__(self, ds, nr):
@@ -408,7 +416,9 @@ class RandomMixData(RNGDataFlow):
 class ConcatData(DataFlow):
     """
-    Concatenate several dataflows. Produce datapoints from them one by one.
+    Concatenate several DataFlow.
+    Produce datapoints from each DataFlow and go to the next when one
+    DataFlow is exhausted.
     """

     def __init__(self, df_lists):
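A sketch of the concatenation order, with the same import assumptions:

    from tensorpack.dataflow import DataFromList, ConcatData

    d1 = DataFromList([[1], [2]], shuffle=False)
    d2 = DataFromList([[3]], shuffle=False)
    df = ConcatData([d1, d2])
    df.reset_state()
    print([dp[0] for dp in df.get_data()])   # [1, 2, 3]: d2 starts after d1 ends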
@@ -439,9 +449,9 @@ class JoinData(DataFlow):
     .. code-block:: none

-        dp1: [c1, c2]
-        dp2: [c3, c4]
-        join: [c1, c2, c3, c4]
+        df1 produces: [c1, c2]
+        df2 produces: [c3, c4]
+        joined: [c1, c2, c3, c4]
     """

     def __init__(self, df_lists):
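A sketch matching the joined layout above, with the same import assumptions:

    from tensorpack.dataflow import DataFromList, JoinData

    df1 = DataFromList([[1], [2]], shuffle=False)      # produces [c1]
    df2 = DataFromList([['a'], ['b']], shuffle=False)  # produces [c2]
    df = JoinData([df1, df2])
    df.reset_state()
    print(list(df.get_data()))   # [[1, 'a'], [2, 'b']]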
@@ -495,9 +505,9 @@ def SelectComponent(ds, idxs):
     .. code-block:: none

-        dp: [c1, c2, c3]
+        original df produces: [c1, c2, c3]
         idxs: [2,1]
-        output dp: [c3, c2]
+        this df: [c3, c2]
     """
     return MapData(ds, lambda dp: [dp[i] for i in idxs])
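A sketch matching the layout above (``SelectComponent`` is assumed exported alongside the other classes):

    from tensorpack.dataflow import DataFromList, SelectComponent

    df = DataFromList([[10, 20, 30]], shuffle=False)
    df = SelectComponent(df, [2, 1])
    df.reset_state()
    print(list(df.get_data()))   # [[30, 20]]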
@@ -561,7 +571,8 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
 class CacheData(ProxyDataFlow):
     """
-    Cache a dataflow completely in memory.
+    Cache the first pass of a DataFlow completely in memory,
+    and produce from the cache thereafter.
     """

     def __init__(self, ds, shuffle=False):
         """
......
@@ -29,9 +29,9 @@ class HDF5Data(RNGDataFlow):
     Zip data from different paths in an HDF5 file.

     Warning:
-        The current implementation will load all data into memory.
+        The current implementation will load all data into memory. (TODO)
     """
-    # TODO lazy load
+    # TODO

     def __init__(self, filename, data_paths, shuffle=True):
         """
@@ -61,7 +61,8 @@ class HDF5Data(RNGDataFlow):
 class LMDBData(RNGDataFlow):
-    """ Read a LMDB database and produce (k,v) pairs """
+    """ Read a LMDB database and produce (k,v) raw string pairs.
+    """

     def __init__(self, lmdb_path, shuffle=True, keys=None):
         """
         Args:
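A hedged usage sketch (the path is hypothetical; keys and values come back as the raw strings stored in the database):

    from tensorpack.dataflow import LMDBData

    df = LMDBData('/path/to/db.lmdb', shuffle=False)
    df.reset_state()
    for k, v in df.get_data():
        pass   # each datapoint is a raw [key, value] pair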
@@ -79,7 +80,7 @@ class LMDBData(RNGDataFlow):
         self._lmdb_path = lmdb_path
         self._shuffle = shuffle

-        self.open_lmdb()
+        self._open_lmdb()
         self._size = self._txn.stat()['entries']
         self._set_keys(keys)
         logger.info("Found {} entries in {}".format(self._size, self._lmdb_path))
@@ -113,7 +114,7 @@ class LMDBData(RNGDataFlow):
         else:
             self.keys = keys

-    def open_lmdb(self):
+    def _open_lmdb(self):
         self._lmdb = lmdb.open(self._lmdb_path,
                                subdir=os.path.isdir(self._lmdb_path),
                                readonly=True, lock=False, readahead=True,
@@ -123,7 +124,7 @@ class LMDBData(RNGDataFlow):
     def reset_state(self):
         self._lmdb.close()
         super(LMDBData, self).reset_state()
-        self.open_lmdb()
+        self._open_lmdb()

     def size(self):
         return self._size
......
@@ -131,6 +131,7 @@ def layer_register(
         if name is not None:  # use scope
             with tf.variable_scope(name) as scope:
+                # this name is only used to suppress logging, doesn't hurt to do some heuristics
                 scope_name = re.sub('tower[0-9]+/', '', scope.name)
                 do_log_shape = log_shape and scope_name not in _LAYER_LOGGED
                 if do_log_shape:
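The heuristic in the added comment just strips a leading ``towerN/`` prefix from the scope name before the logging check, e.g.:

    import re

    print(re.sub('tower[0-9]+/', '', 'tower0/conv1/W'))   # 'conv1/W'
    print(re.sub('tower[0-9]+/', '', 'conv1/W'))          # unchanged: 'conv1/W'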
......