Commit 6f9f4cd9 authored by Yuxin Wu

locallyshuffle

parent 1126fa5c
......@@ -5,6 +5,7 @@
from abc import abstractmethod, ABCMeta
from ..utils import get_rng
__all__ = ['DataFlow', 'ProxyDataFlow']
......
......@@ -7,13 +7,13 @@ import copy
import numpy as np
from collections import deque
from six.moves import range, map
from .base import DataFlow, ProxyDataFlow
from .base import DataFlow, ProxyDataFlow, RNGDataFlow
from ..utils import *
__all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData',
__all__ = ['BatchData', 'FixedSizeData', 'MapData',
'RepeatedData', 'MapDataComponent', 'RandomChooseData',
'RandomMixData', 'JoinData', 'ConcatData', 'SelectComponent',
'DataFromQueue', 'LocallyShuffleData']
'LocallyShuffleData']
class BatchData(ProxyDataFlow):
def __init__(self, ds, batch_size, remainder=False):
......@@ -135,25 +135,6 @@ class RepeatedData(ProxyDataFlow):
for dp in self.ds.get_data():
yield dp
class FakeData(RNGDataFlow):
    """DataFlow that fabricates random float32 arrays of the given shapes."""

    def __init__(self, shapes, size):
        """
        :param shapes: a list of lists/tuples; one array is generated per entry
        :param size: size of this DataFlow (coerced with ``int``)
        """
        super(FakeData, self).__init__()
        self.shapes = shapes
        self._size = int(size)

    def size(self):
        """Return the number of datapoints produced by one pass of get_data."""
        return self._size

    def get_data(self):
        """Yield ``size()`` datapoints, each a list with one array per shape."""
        # Read the size once up front, exactly like a single range() call would.
        remaining = self._size
        while remaining > 0:
            datapoint = []
            for shape in self.shapes:
                sample = self.rng.random_sample(shape)
                # Stored as float32 rather than the dtype random_sample returns.
                datapoint.append(sample.astype('float32'))
            yield datapoint
            remaining -= 1
class MapData(ProxyDataFlow):
""" Apply map/filter a function on the datapoint"""
def __init__(self, ds, func):
......@@ -323,24 +304,31 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
# Build the shuffle buffer: a deque bounded by `cache_size`, plus an iterator
# obtained from a RepeatedData wrapper around `ds`.
# :param ds: the DataFlow to wrap (forwarded to ProxyDataFlow.__init__)
# :param cache_size: used as the deque's maxlen
def __init__(self, ds, cache_size):
ProxyDataFlow.__init__(self, ds)
RNGDataFlow.__init__(self)
# NOTE(review): self.q is assigned twice below with the same maxlen; this looks
# like the removed and added sides of the diff shown together. Confirm which
# lines are current (and whether self.cache_size is still wanted).
self.cache_size = cache_size
self.q = deque(maxlen=self.cache_size)
self.q = deque(maxlen=cache_size)
# presumably -1 asks RepeatedData to repeat forever; verify against RepeatedData
self.ds_wrap = RepeatedData(ds, -1)
self.ds_itr = self.ds_wrap.get_data()
self.current_cnt = 0
# Reset both base classes, then rebuild the RepeatedData wrapper and its
# iterator and zero the counter.
def reset_state(self):
ProxyDataFlow.reset_state(self)
RNGDataFlow.reset_state(self)
# A new wrapper/iterator pair is created here, replacing the one made in __init__.
# presumably -1 asks RepeatedData to repeat forever; verify against RepeatedData
self.ds_wrap = RepeatedData(self.ds, -1)
self.ds_itr = self.ds_wrap.get_data()
self.current_cnt = 0
# NOTE(review): this span is a diff rendered without +/- markers. A stub
# get_data, a DataFromQueue class and a deque-based shuffling get_data appear
# interleaved line by line. Confirm which lines are current before relying on it.
def get_data(self):
# TODO
pass
class DataFromQueue(DataFlow):
""" Provide data from a queue """
def __init__(self, queue):
self.queue = queue
def get_data(self):
# Top up the deque from self.ds_itr until it holds maxlen items.
for _ in range(self.q.maxlen - len(self.q)):
self.q.append(next(self.ds_itr))
# cnt counts datapoints yielded so far in this call.
cnt = 0
while True:
# NOTE(review): reads self.queue, which only DataFromQueue.__init__ above
# assigns; presumably this line belongs to DataFromQueue.get_data. Confirm.
yield self.queue.get()
# Shuffle the buffered datapoints in place.
self.rng.shuffle(self.q)
for _ in range(self.q.maxlen):
# Hand out the datapoint at the left end of the deque.
yield self.q.popleft()
cnt += 1
# Stop once size() datapoints have been yielded.
if cnt == self.size():
return
# Append the next item from self.ds_itr to the deque.
self.q.append(next(self.ds_itr))
def SelectComponent(ds, idxs):
"""
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: raw.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import numpy as np
from six.moves import range
from .base import DataFlow, RNGDataFlow
__all__ = ['FakeData', 'DataFromQueue', 'DataFromList']
class FakeData(RNGDataFlow):
    """DataFlow that fabricates random float32 arrays of the given shapes."""

    def __init__(self, shapes, size):
        """
        :param shapes: a list of lists/tuples; one array is generated per entry
        :param size: size of this DataFlow (coerced with ``int``)
        """
        super(FakeData, self).__init__()
        self.shapes = shapes
        self._size = int(size)

    def size(self):
        """Return the number of datapoints produced by one pass of get_data."""
        return self._size

    def get_data(self):
        """Yield ``size()`` datapoints, each a list with one array per shape."""
        # Read the size once up front, exactly like a single range() call would.
        remaining = self._size
        while remaining > 0:
            datapoint = []
            for shape in self.shapes:
                sample = self.rng.random_sample(shape)
                # Stored as float32 rather than the dtype random_sample returns.
                datapoint.append(sample.astype('float32'))
            yield datapoint
            remaining -= 1
class DataFromQueue(DataFlow):
    """Endless DataFlow whose datapoints are taken from a queue.

    Every step calls ``queue.get()`` and yields what it returns; the
    generator never finishes on its own.
    """

    def __init__(self, queue):
        """
        :param queue: an object with a ``get()`` method returning the next datapoint
        """
        self.queue = queue

    def get_data(self):
        """Yield one datapoint per ``self.queue.get()`` call, forever."""
        while True:
            # self.queue is looked up on every step, so a reassigned queue is honoured.
            datapoint = self.queue.get()
            yield datapoint
class DataFromList(RNGDataFlow):
    """Produce datapoints from an in-memory list, optionally in random order."""

    def __init__(self, lst, shuffle=True):
        """
        :param lst: the list of datapoints to serve
        :param shuffle: if True, each call to get_data visits the list in a
            newly shuffled order; if False, in list order
        """
        super(DataFromList, self).__init__()
        self.lst = lst
        self.shuffle = shuffle

    def size(self):
        """Return the number of datapoints, i.e. ``len(lst)``."""
        return len(self.lst)

    def get_data(self):
        """Yield every element of the list exactly once."""
        if not self.shuffle:
            for dp in self.lst:
                yield dp
            return
        # rng.shuffle() permutes its argument in place and returns None, so the
        # index array has to be bound to a name *before* it is shuffled.
        # Assigning the call's result (as the old code did) left idxs as None
        # and the loop below raised TypeError.
        idxs = np.arange(len(self.lst))
        self.rng.shuffle(idxs)
        for k in idxs:
            yield self.lst[k]
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment