Commit 7b91ba1d authored by Yuxin Wu

write some more comments in dataflow

parent bee6798b
@@ -13,9 +13,10 @@ __all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow', 'DataFlowTerminated']
 class DataFlowTerminated(BaseException):
     """
-    An exception indicating that the DataFlow is unable to produce any more data:
-    calling :meth:`get_data` will not give a valid iterator any more.
-    In most DataFlow this will not be raised.
+    An exception indicating that the DataFlow is unable to produce any more
+    data, i.e. something wrong happened so that calling :meth:`get_data`
+    cannot give a valid iterator any more.
+    In most DataFlow this will never be raised.
     """
     pass
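For context, a minimal sketch of the contract these docstrings describe: a DataFlow yields datapoints from `get_data()` after `reset_state()` has been called once. The class `MyFlow` and its fake data are hypothetical, not part of this commit.

```python
from tensorpack.dataflow import RNGDataFlow

class MyFlow(RNGDataFlow):
    """Hypothetical DataFlow yielding [image, label] datapoints."""
    def size(self):
        return 100

    def get_data(self):
        # self.rng is set up by RNGDataFlow.reset_state()
        for _ in range(self.size()):
            img = self.rng.rand(28, 28).astype('float32')
            label = int(self.rng.randint(10))
            yield [img, label]

df = MyFlow()
df.reset_state()            # call once before iterating
for dp in df.get_data():    # dp == [img, label]
    pass
```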
@@ -80,7 +81,9 @@ class RNGDataFlow(DataFlow):
 class ProxyDataFlow(DataFlow):
-    """ Base class for DataFlow that proxies another"""
+    """ Base class for DataFlow that proxies another.
+        Every method is proxied to ``self.ds`` unless overridden by a subclass.
+    """

     def __init__(self, ds):
         """
@@ -90,9 +93,6 @@ class ProxyDataFlow(DataFlow):
         self.ds = ds

     def reset_state(self):
-        """
-        Reset state of the proxied DataFlow.
-        """
         self.ds.reset_state()

     def size(self):
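To illustrate the newly documented proxying behavior, a small sketch (the subclass name is hypothetical): only `get_data` is overridden, while `size()` and `reset_state()` fall through to ``self.ds``.

```python
from tensorpack.dataflow import ProxyDataFlow

class PrintFirstComponent(ProxyDataFlow):
    """Hypothetical proxy that logs the first component of each datapoint."""
    def get_data(self):
        for dp in self.ds.get_data():
            print(dp[0])
            yield dp
    # size() and reset_state() are not overridden, so they are
    # proxied to self.ds as the docstring above describes.
```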
...
@@ -32,7 +32,7 @@ class TestDataSpeed(ProxyDataFlow):
         self.test_size = size

     def get_data(self):
-        """ Will start testing at the beginning, then produce data normally. """
+        """ Will run testing at the beginning, then produce data normally. """
         self.start_test()
         for dp in self.ds.get_data():
             yield dp
@@ -62,22 +62,23 @@ class BatchData(ProxyDataFlow):
     Concat datapoints into batches.
     It produces datapoints of the same number of components as ``ds``, but
     each component has one new extra dimension of size ``batch_size``.
-    The new component can be a list of the original datapoints, or an ndarray
-    of the original datapoints.
+    A batch can be either a list of original components, or (by default)
+    a numpy array of original components.
     """

     def __init__(self, ds, batch_size, remainder=False, use_list=False):
         """
         Args:
-            ds (DataFlow): Its components must be either scalars or :class:`np.ndarray`.
-                Each component has to be of the same shape across datapoints.
+            ds (DataFlow): When ``use_list=False``, the components of ``ds``
+                must be either scalars or :class:`np.ndarray`, and the
+                components have to have consistent shapes across ``ds``.
             batch_size(int): batch size
-            remainder (bool): whether to return the remaining data smaller than a batch_size.
-                If set True, it will possibly generates a data point of a smaller batch size.
-                Otherwise, all generated data are guranteed to have the same size.
-            use_list (bool): if True, it will run faster by producing a list
-                of datapoints instead of an ndarray of datapoints, avoiding an
-                extra copy.
+            remainder (bool): When the remaining datapoints in ``ds`` are not
+                enough to form a batch, whether or not to also produce the
+                remaining data as a smaller batch.
+                If set to False, all generated datapoints are guaranteed to
+                have the same batch size.
+            use_list (bool): if True, each component will contain a list of
+                datapoints instead of a numpy array of datapoints. This also
+                avoids an extra copy.
         """
         super(BatchData, self).__init__(ds)
         if not remainder:
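A usage sketch for the arguments documented above; `df` stands for any hypothetical upstream DataFlow yielding `[img, label]`:

```python
from tensorpack.dataflow import BatchData

batched = BatchData(df, 32)                   # drop the final incomplete batch
batched = BatchData(df, 32, remainder=True)   # also yield the smaller last batch
batched = BatchData(df, 32, use_list=True)    # components are lists, not np.ndarray
```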
@@ -152,9 +153,10 @@ class BatchDataByShape(BatchData):
     datapoints of different shape, and batches will be formed from those who
     have the same shape.
-    It is implemented by a dict{shape -> datapoints}.
-    Datapoints of uncommon shapes may never be enough to form a batch and
-    never get generated.
+    Note:
+        It is implemented by a dict{shape -> datapoints}.
+        Datapoints of uncommon shapes may never be enough to form a batch and
+        never get generated.
     """

     def __init__(self, ds, batch_size, idx):
         """
@@ -184,8 +186,8 @@ class BatchDataByShape(BatchData):
 class FixedSizeData(ProxyDataFlow):
-    """ Generate data from another DataFlow, but with a fixed size.
-        The iterator of the underlying DataFlow will be kept if not exhausted.
+    """ Generate data from another DataFlow, but with a fixed total count.
+        The iterator state of the underlying DataFlow will be kept if not exhausted.
     """

     def __init__(self, ds, size):
         """
@@ -220,19 +222,21 @@ class FixedSizeData(ProxyDataFlow):
 class MapData(ProxyDataFlow):
-    """ Apply a mapper/filter on the DataFlow"""
+    """
+    Apply a mapper/filter on the DataFlow.
+
+    Note:
+        1. Please make sure func doesn't modify the components
+           unless you're certain it's safe.
+        2. If you discard some datapoints, ``ds.size()`` will be incorrect.
+    """

     def __init__(self, ds, func):
         """
         Args:
             ds (DataFlow): input DataFlow
             func (datapoint -> datapoint | None): takes a datapoint and returns a new
-                datapoint. Return None to discard this data point.
-                Note that if you use the filter feature, ``ds.size()`` will be incorrect.
-
-        Note:
-            Please make sure func doesn't modify the components
-            unless you're certain it's safe.
+                datapoint. Return None to discard this datapoint.
         """
         super(MapData, self).__init__(ds)
         self.func = func
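A sketch of the mapper/filter usage described above (`df` is again a hypothetical upstream DataFlow); once datapoints are discarded, the `ds.size()` caveat from the new docstring applies:

```python
from tensorpack.dataflow import MapData

# Filter: returning None discards the datapoint.
evens = MapData(df, lambda dp: dp if dp[1] % 2 == 0 else None)

# Mapper: build a new datapoint instead of mutating dp in place,
# following the "func doesn't modify the components" note.
scaled = MapData(df, lambda dp: [dp[0] * 2, dp[1]])
```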
@@ -245,20 +249,23 @@ class MapData(ProxyDataFlow):
 class MapDataComponent(MapData):
-    """ Apply a mapper/filter on a datapoint component"""
+    """
+    Apply a mapper/filter on a datapoint component.
+
+    Note:
+        1. This dataflow itself doesn't modify the datapoints.
+           But please make sure func doesn't modify the components
+           unless you're certain it's safe.
+        2. If you discard some datapoints, ``ds.size()`` will be incorrect.
+    """

     def __init__(self, ds, func, index=0):
         """
         Args:
             ds (DataFlow): input DataFlow.
             func (TYPE -> TYPE|None): takes ``dp[index]``, returns a new value for ``dp[index]``.
                 return None to discard this datapoint.
-                Note that if you use the filter feature, ``ds.size()`` will be incorrect.
             index (int): index of the component.
-        Note:
-            This proxy itself doesn't modify the datapoints.
-            But please make sure func doesn't modify the components
-            unless you're certain it's safe.
         """
         def f(dp):
             r = func(dp[index])
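The per-component variant, using the `index` argument documented above (hypothetical `df` yielding `[img, label]`):

```python
from tensorpack.dataflow import MapDataComponent

# Only component 0 (the image) is mapped; the label passes through.
normalized = MapDataComponent(df, lambda img: img / 255.0, index=0)
```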
@@ -272,7 +279,8 @@ class MapDataComponent(MapData):
 class RepeatedData(ProxyDataFlow):
     """ Take data points from another DataFlow and produce them until
-        it's exhausted for certain amount of times.
+        it's exhausted a certain number of times, i.e.:
+        dp1, dp2, ..., dpn, dp1, dp2, ..., dpn
     """

     def __init__(self, ds, nr):
@@ -306,9 +314,9 @@ class RepeatedData(ProxyDataFlow):
 class RepeatedDataPoint(ProxyDataFlow):
     """ Take data points from another DataFlow and produce them a
-        certain number of times dp1, ..., dp1, dp2, ..., dp2, ...
+        certain number of times, i.e.:
+        dp1, dp1, ..., dp1, dp2, ..., dp2, ...
     """

     def __init__(self, ds, nr):
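A sketch contrasting the two orderings spelled out in the docstrings above (hypothetical `df`):

```python
from tensorpack.dataflow import RepeatedData, RepeatedDataPoint

epochs = RepeatedData(df, 2)        # dp1, dp2, ..., dpn, dp1, dp2, ..., dpn
points = RepeatedDataPoint(df, 2)   # dp1, dp1, dp2, dp2, ..., dpn, dpn
```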
@@ -408,7 +416,9 @@ class RandomMixData(RNGDataFlow):
 class ConcatData(DataFlow):
     """
-    Concatenate several dataflows. Produce datapoints from them one by one.
+    Concatenate several DataFlow.
+    Produce datapoints from each DataFlow and go to the next when one
+    DataFlow is exhausted.
     """

     def __init__(self, df_lists):
@@ -439,9 +449,9 @@ class JoinData(DataFlow):
     .. code-block:: none

-        dp1: [c1, c2]
-        dp2: [c3, c4]
-        join: [c1, c2, c3, c4]
+        df1 produces: [c1, c2]
+        df2 produces: [c3, c4]
+        joined: [c1, c2, c3, c4]
     """

     def __init__(self, df_lists):
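To make the distinction between the two combiners concrete (hypothetical `df1`, `df2`):

```python
from tensorpack.dataflow import ConcatData, JoinData

# ConcatData: all of df1, then all of df2.
cat = ConcatData([df1, df2])

# JoinData: one datapoint from each, components concatenated,
# e.g. [c1, c2] + [c3, c4] -> [c1, c2, c3, c4].
joined = JoinData([df1, df2])
```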
@@ -495,9 +505,9 @@ def SelectComponent(ds, idxs):
     .. code-block:: none

-        dp: [c1, c2, c3]
+        original df produces: [c1, c2, c3]
         idxs: [2,1]
-        output dp: [c3, c2]
+        this df: [c3, c2]
     """
     return MapData(ds, lambda dp: [dp[i] for i in idxs])
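Matching the rewritten example above (hypothetical `df` yielding `[c1, c2, c3]`):

```python
from tensorpack.dataflow import SelectComponent

picked = SelectComponent(df, [2, 1])   # datapoints become [c3, c2]
```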
@@ -561,7 +571,8 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
 class CacheData(ProxyDataFlow):
     """
-    Cache a dataflow completely in memory.
+    Cache the first pass of a DataFlow completely in memory,
+    and produce from the cache thereafter.
     """

     def __init__(self, ds, shuffle=False):
         """
...
@@ -29,9 +29,9 @@ class HDF5Data(RNGDataFlow):
     Zip data from different paths in an HDF5 file.

     Warning:
-        The current implementation will load all data into memory.
+        The current implementation will load all data into memory. (TODO)
     """
-    # TODO lazy load
+    # TODO

     def __init__(self, filename, data_paths, shuffle=True):
         """
@@ -61,7 +61,8 @@ class HDF5Data(RNGDataFlow):
 class LMDBData(RNGDataFlow):
-    """ Read a LMDB database and produce (k,v) pairs """
+    """ Read an LMDB database and produce (k,v) raw string pairs.
+    """

     def __init__(self, lmdb_path, shuffle=True, keys=None):
         """
         Args:
@@ -79,7 +80,7 @@ class LMDBData(RNGDataFlow):
         self._lmdb_path = lmdb_path
         self._shuffle = shuffle
-        self.open_lmdb()
+        self._open_lmdb()
         self._size = self._txn.stat()['entries']
         self._set_keys(keys)
         logger.info("Found {} entries in {}".format(self._size, self._lmdb_path))
@@ -113,7 +114,7 @@ class LMDBData(RNGDataFlow):
         else:
             self.keys = keys

-    def open_lmdb(self):
+    def _open_lmdb(self):
         self._lmdb = lmdb.open(self._lmdb_path,
                                subdir=os.path.isdir(self._lmdb_path),
                                readonly=True, lock=False, readahead=True,
@@ -123,7 +124,7 @@ class LMDBData(RNGDataFlow):
     def reset_state(self):
         self._lmdb.close()
         super(LMDBData, self).reset_state()
-        self.open_lmdb()
+        self._open_lmdb()

     def size(self):
         return self._size
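A consumption sketch for `LMDBData` (the path is hypothetical); `reset_state()` triggers the close-and-reopen sequence shown in the hunk above:

```python
from tensorpack.dataflow import LMDBData

df = LMDBData('/path/to/data.lmdb', shuffle=False)   # hypothetical path
df.reset_state()              # closes and re-opens the LMDB environment
for k, v in df.get_data():    # raw (key, value) string pairs
    pass
```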
...
@@ -131,6 +131,7 @@ def layer_register(
         if name is not None:  # use scope
             with tf.variable_scope(name) as scope:
+                # this name is only used to suppress logging, doesn't hurt to do some heuristics
                 scope_name = re.sub('tower[0-9]+/', '', scope.name)
                 do_log_shape = log_shape and scope_name not in _LAYER_LOGGED
                 if do_log_shape:
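The heuristic mentioned in the new comment just strips a `towerN/` prefix, so the same layer replicated across GPU towers is logged only once:

```python
import re

re.sub('tower[0-9]+/', '', 'tower0/conv1')   # -> 'conv1'
re.sub('tower[0-9]+/', '', 'conv1')          # -> 'conv1' (unchanged)
```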
...