Commit 17eca8ec authored by Yuxin Wu's avatar Yuxin Wu

keep_state option for FixedSizeData

parent fec86fec
......@@ -190,18 +190,31 @@ class BatchDataByShape(BatchData):
class FixedSizeData(ProxyDataFlow):
""" Generate data from another DataFlow, but with a fixed total count.
The iterator state of the underlying DataFlow will be kept if not exhausted.
"""
def __init__(self, ds, size):
def __init__(self, ds, size, keep_state=True):
"""
Args:
ds (DataFlow): input dataflow
size (int): size
keep_state (bool): keep the iterator state of ``ds``
between calls to :meth:`get_data()`, so that the
next call will continue the previous iteration over ``ds``,
instead of reinitializing an iterator.
Examples:
.. code-block:: none
ds produces: 1, 2, 3, 4, 5; 1, 2, 3, 4, 5; ...
FixedSizeData(ds, 3, True): 1, 2, 3; 4, 5, 1; 2, 3, 4; ...
FixedSizeData(ds, 3, False): 1, 2, 3; 1, 2, 3; ...
FixedSizeData(ds, 6, False): 1, 2, 3, 4, 5, 1; 1, 2, 3, 4, 5, 1;...
"""
super(FixedSizeData, self).__init__(ds)
self._size = int(size)
self.itr = None
self._guard = DataFlowReentrantGuard()
self._keep = keep_state
def size(self):
return self._size
......@@ -221,6 +234,8 @@ class FixedSizeData(ProxyDataFlow):
cnt += 1
yield dp
if cnt == self._size:
if not self._keep:
self.itr = None
return
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment