Commit 807296b3 authored by Yuxin Wu's avatar Yuxin Wu

fix locallyshuffledata

parent 3fcd3b57
...@@ -318,24 +318,30 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow): ...@@ -318,24 +318,30 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
def reset_state(self): def reset_state(self):
ProxyDataFlow.reset_state(self) ProxyDataFlow.reset_state(self)
RNGDataFlow.reset_state(self) RNGDataFlow.reset_state(self)
self.ds_wrap = RepeatedData(self.ds, -1) self.ds_itr = self.ds.get_data()
self.ds_itr = self.ds_wrap.get_data()
self.current_cnt = 0 self.current_cnt = 0
def get_data(self): def get_data(self):
for _ in range(self.q.maxlen - len(self.q)): for _ in range(self.q.maxlen - len(self.q)):
self.q.append(next(self.ds_itr)) try:
cnt = 0 self.q.append(next(self.ds_itr))
except StopIteration:
logger.error("LocallyShuffleData: cache_size is larger than the size of ds!")
while True: while True:
self.rng.shuffle(self.q) self.rng.shuffle(self.q)
for _ in range(self.q.maxlen): for _ in range(self.q.maxlen):
yield self.q.popleft() yield self.q.popleft()
self.q.append(next(self.ds_itr)) try:
cnt += 1 self.q.append(next(self.ds_itr))
if cnt == self.size(): except StopIteration:
# produce the rest and return
self.rng.shuffle(self.q)
for v in self.q:
yield v
return return
def SelectComponent(ds, idxs): def SelectComponent(ds, idxs):
""" """
:param ds: a :mod:`DataFlow` instance :param ds: a :mod:`DataFlow` instance
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment