Commit 9c30c49d authored by Yuxin Wu's avatar Yuxin Wu

use infinite loop in locallyshuffle

parent a761a839
...@@ -458,32 +458,31 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow): ...@@ -458,32 +458,31 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
def reset_state(self): def reset_state(self):
ProxyDataFlow.reset_state(self) ProxyDataFlow.reset_state(self)
RNGDataFlow.reset_state(self) RNGDataFlow.reset_state(self)
self.ds_itr = self.ds.get_data() self.ds_itr = RepeatedData(self.ds).get_data()
self.current_cnt = 0 self.current_cnt = 0
def _add_data(self):
dp = next(self.ds_itr)
for _ in range(self.nr_reuse):
self.q.append(dp)
def get_data(self): def get_data(self):
def add_next(): # fill queue
dp = next(self.ds_itr) while self.q.maxlen > len(self.q):
for _ in range(self.nr_reuse): self._add_data()
self.q.append(dp)
try: sz = self.size()
while self.q.maxlen > len(self.q): cnt = 0
add_next()
except StopIteration:
logger.error("LocallyShuffleData: cache_size is larger than the size of ds!")
while True: while True:
self.rng.shuffle(self.q) self.rng.shuffle(self.q)
for _ in range(self.q.maxlen): for _ in range(self.q.maxlen):
# the inner loop maintains the queue size (almost) unchanged
for _ in range(self.nr_reuse): for _ in range(self.nr_reuse):
yield self.q.popleft() yield self.q.popleft()
try: cnt += self.nr_reuse
add_next() if cnt >= sz:
except StopIteration:
# produce the rest and return
self.rng.shuffle(self.q)
for v in self.q:
yield v
return return
self._add_data()
class PrintData(ProxyDataFlow): class PrintData(ProxyDataFlow):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment