Commit 17eca8ec authored by Yuxin Wu's avatar Yuxin Wu

keep_state option for FixedSizeData

parent fec86fec
...@@ -190,18 +190,31 @@ class BatchDataByShape(BatchData): ...@@ -190,18 +190,31 @@ class BatchDataByShape(BatchData):
class FixedSizeData(ProxyDataFlow): class FixedSizeData(ProxyDataFlow):
""" Generate data from another DataFlow, but with a fixed total count. """ Generate data from another DataFlow, but with a fixed total count.
The iterator state of the underlying DataFlow will be kept if not exhausted.
""" """
def __init__(self, ds, size): def __init__(self, ds, size, keep_state=True):
""" """
Args: Args:
ds (DataFlow): input dataflow ds (DataFlow): input dataflow
size (int): size size (int): size
keep_state (bool): keep the iterator state of ``ds``
between calls to :meth:`get_data()`, so that the
next call will continue the previous iteration over ``ds``,
instead of reinitializing an iterator.
Examples:
.. code-block:: none
ds produces: 1, 2, 3, 4, 5; 1, 2, 3, 4, 5; ...
FixedSizeData(ds, 3, True): 1, 2, 3; 4, 5, 1; 2, 3, 4; ...
FixedSizeData(ds, 3, False): 1, 2, 3; 1, 2, 3; ...
FixedSizeData(ds, 6, False): 1, 2, 3, 4, 5, 1; 1, 2, 3, 4, 5, 1;...
""" """
super(FixedSizeData, self).__init__(ds) super(FixedSizeData, self).__init__(ds)
self._size = int(size) self._size = int(size)
self.itr = None self.itr = None
self._guard = DataFlowReentrantGuard() self._guard = DataFlowReentrantGuard()
self._keep = keep_state
def size(self): def size(self):
return self._size return self._size
...@@ -221,6 +234,8 @@ class FixedSizeData(ProxyDataFlow): ...@@ -221,6 +234,8 @@ class FixedSizeData(ProxyDataFlow):
cnt += 1 cnt += 1
yield dp yield dp
if cnt == self._size: if cnt == self._size:
if not self._keep:
self.itr = None
return return
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment