Commit 8403a009 authored by Yuxin Wu

nr_reuse in locallyshuffledata

parent 53291500
...@@ -311,9 +311,15 @@ class JoinData(DataFlow): ...@@ -311,9 +311,15 @@ class JoinData(DataFlow):
del itr del itr
class LocallyShuffleData(ProxyDataFlow, RNGDataFlow): class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
def __init__(self, ds, cache_size): def __init__(self, ds, cache_size, nr_reuse=1):
"""
Cache a number of datapoints and shuffle them.
:param cache_size: size of the cache
:param nr_reuse: number of times each datapoint is reused
"""
ProxyDataFlow.__init__(self, ds) ProxyDataFlow.__init__(self, ds)
self.q = deque(maxlen=cache_size) self.q = deque(maxlen=cache_size)
self.nr_reuse = nr_reuse
def reset_state(self): def reset_state(self):
ProxyDataFlow.reset_state(self) ProxyDataFlow.reset_state(self)
...@@ -322,17 +328,22 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow): ...@@ -322,17 +328,22 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
self.current_cnt = 0 self.current_cnt = 0
def get_data(self): def get_data(self):
def add_next():
dp = next(self.ds_itr)
for _ in range(self.nr_reuse):
self.q.append(dp)
for _ in range(self.q.maxlen - len(self.q)): for _ in range(self.q.maxlen - len(self.q)):
try: try:
self.q.append(next(self.ds_itr)) add_next()
except StopIteration: except StopIteration:
logger.error("LocallyShuffleData: cache_size is larger than the size of ds!") logger.error("LocallyShuffleData: cache_size is larger than the size of ds!")
while True: while True:
self.rng.shuffle(self.q) self.rng.shuffle(self.q)
for _ in range(self.q.maxlen): for _ in range(self.q.maxlen):
for _ in range(self.nr_reuse):
yield self.q.popleft() yield self.q.popleft()
try: try:
self.q.append(next(self.ds_itr)) add_next()
except StopIteration: except StopIteration:
# produce the rest and return # produce the rest and return
self.rng.shuffle(self.q) self.rng.shuffle(self.q)
...@@ -341,7 +352,6 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow): ...@@ -341,7 +352,6 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
return return
def SelectComponent(ds, idxs): def SelectComponent(ds, idxs):
""" """
:param ds: a :mod:`DataFlow` instance :param ds: a :mod:`DataFlow` instance
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment