Commit 6f9f4cd9 authored by Yuxin Wu

locallyshuffle

parent 1126fa5c
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
from abc import abstractmethod, ABCMeta from abc import abstractmethod, ABCMeta
from ..utils import get_rng
__all__ = ['DataFlow', 'ProxyDataFlow'] __all__ = ['DataFlow', 'ProxyDataFlow']
......
...@@ -7,13 +7,13 @@ import copy ...@@ -7,13 +7,13 @@ import copy
import numpy as np import numpy as np
from collections import deque from collections import deque
from six.moves import range, map from six.moves import range, map
from .base import DataFlow, ProxyDataFlow from .base import DataFlow, ProxyDataFlow, RNGDataFlow
from ..utils import * from ..utils import *
__all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData', __all__ = ['BatchData', 'FixedSizeData', 'MapData',
'RepeatedData', 'MapDataComponent', 'RandomChooseData', 'RepeatedData', 'MapDataComponent', 'RandomChooseData',
'RandomMixData', 'JoinData', 'ConcatData', 'SelectComponent', 'RandomMixData', 'JoinData', 'ConcatData', 'SelectComponent',
'DataFromQueue', 'LocallyShuffleData'] 'LocallyShuffleData']
class BatchData(ProxyDataFlow): class BatchData(ProxyDataFlow):
def __init__(self, ds, batch_size, remainder=False): def __init__(self, ds, batch_size, remainder=False):
...@@ -135,25 +135,6 @@ class RepeatedData(ProxyDataFlow): ...@@ -135,25 +135,6 @@ class RepeatedData(ProxyDataFlow):
for dp in self.ds.get_data(): for dp in self.ds.get_data():
yield dp yield dp
class FakeData(RNGDataFlow):
    """ Produce datapoints made of random float32 arrays with the given shapes"""

    def __init__(self, shapes, size):
        """
        :param shapes: a list of lists/tuples; one array is generated per entry
        :param size: number of datapoints per pass, stored as ``int(size)``
        """
        super(FakeData, self).__init__()
        self.shapes = shapes
        self._size = int(size)

    def size(self):
        """ Return the number of datapoints yielded by one pass of get_data"""
        return self._size

    def get_data(self):
        """ Yield ``size()`` datapoints, each a list with one array per shape"""
        # NOTE(review): self.rng is not set in this class; presumably the
        # RNGDataFlow base provides it -- confirm against base.py.
        remaining = self._size
        while remaining > 0:
            remaining -= 1
            yield [self.rng.random_sample(shape).astype('float32')
                   for shape in self.shapes]
class MapData(ProxyDataFlow): class MapData(ProxyDataFlow):
""" Apply map/filter a function on the datapoint""" """ Apply map/filter a function on the datapoint"""
def __init__(self, ds, func): def __init__(self, ds, func):
...@@ -323,24 +304,31 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow): ...@@ -323,24 +304,31 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
def __init__(self, ds, cache_size): def __init__(self, ds, cache_size):
ProxyDataFlow.__init__(self, ds) ProxyDataFlow.__init__(self, ds)
RNGDataFlow.__init__(self) RNGDataFlow.__init__(self)
self.cache_size = cache_size self.q = deque(maxlen=cache_size)
self.q = deque(maxlen=self.cache_size) self.ds_wrap = RepeatedData(ds, -1)
self.ds_itr = self.ds_wrap.get_data()
self.current_cnt = 0
def reset_state(self): def reset_state(self):
ProxyDataFlow.reset_state(self)
RNGDataFlow.reset_state(self) RNGDataFlow.reset_state(self)
self.ds_wrap = RepeatedData(self.ds, -1)
self.ds_itr = self.ds_wrap.get_data()
self.current_cnt = 0
def get_data(self): def get_data(self):
# TODO for _ in range(self.q.maxlen - len(self.q)):
pass self.q.append(next(self.ds_itr))
cnt = 0
class DataFromQueue(DataFlow):
""" Provide data from a queue """
def __init__(self, queue):
self.queue = queue
def get_data(self):
while True: while True:
yield self.queue.get() self.rng.shuffle(self.q)
for _ in range(self.q.maxlen):
yield self.q.popleft()
cnt += 1
if cnt == self.size():
return
self.q.append(next(self.ds_itr))
def SelectComponent(ds, idxs): def SelectComponent(ds, idxs):
""" """
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: raw.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import numpy as np
from six.moves import range
from .base import DataFlow, RNGDataFlow
__all__ = ['FakeData', 'DataFromQueue', 'DataFromList']
class FakeData(RNGDataFlow):
    """ Generate fake datapoints: lists of random float32 arrays of given shapes"""

    def __init__(self, shapes, size):
        """
        :param shapes: a list of lists/tuples, one shape per datapoint component
        :param size: size of this DataFlow (converted with ``int``)
        """
        super(FakeData, self).__init__()
        self.shapes = shapes
        self._size = int(size)

    def size(self):
        """ Return how many datapoints a single pass of get_data produces"""
        return self._size

    def get_data(self):
        """ Yield ``size()`` datapoints, drawing a fresh sample for every shape"""
        # NOTE(review): self.rng is not assigned here; presumably it comes from
        # the RNGDataFlow base class -- confirm against base.py.
        for _ in range(self._size):
            datapoint = []
            for shape in self.shapes:
                sample = self.rng.random_sample(shape)
                # Cast each sample to float32 before handing it out.
                datapoint.append(sample.astype('float32'))
            yield datapoint
class DataFromQueue(DataFlow):
    """ Produce data by repeatedly taking items from a queue """

    def __init__(self, queue):
        """
        :param queue: an object exposing a ``get()`` method; each call
            supplies the next datapoint
        """
        self.queue = queue

    def get_data(self):
        """ Yield items from the queue forever; this generator never finishes
        on its own """
        while True:
            item = self.queue.get()
            yield item
class DataFromList(RNGDataFlow):
    """ Produce data from a list"""

    def __init__(self, lst, shuffle=True):
        """
        :param lst: the list whose elements are produced as datapoints
        :param shuffle: if True, produce the elements in a newly shuffled
            order on every pass; if False, keep the list order
        """
        super(DataFromList, self).__init__()
        self.lst = lst
        self.shuffle = shuffle

    def size(self):
        """ Return the number of elements in the list"""
        return len(self.lst)

    def get_data(self):
        """ Yield every element of the list once, shuffled if requested"""
        if not self.shuffle:
            for k in self.lst:
                yield k
        else:
            # shuffle() permutes its argument in place and returns None, so
            # the index array must be created first and shuffled afterwards;
            # assigning the return value would leave nothing to iterate over.
            # NOTE(review): self.rng is presumably provided by RNGDataFlow
            # -- confirm against base.py.
            idxs = np.arange(len(self.lst))
            self.rng.shuffle(idxs)
            for k in idxs:
                yield self.lst[k]
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment