Commit 8c9c61d3 authored by Yuxin Wu's avatar Yuxin Wu

joindata

parent fe548cf0
......@@ -9,7 +9,7 @@ from .base import DataFlow, ProxyDataFlow
from ..utils import *
__all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData',
'MapDataComponent', 'RandomChooseData', 'RandomMixData']
'MapDataComponent', 'RandomChooseData', 'RandomMixData', 'JoinData']
class BatchData(ProxyDataFlow):
def __init__(self, ds, batch_size, remainder=False):
......@@ -218,3 +218,25 @@ class RandomMixData(DataFlow):
for k in idxs:
yield next(itrs[k])
class JoinData(DataFlow):
"""
Concatenate several dataflows
"""
def __init__(self, df_lists):
"""
df_lists: list of dataflow
"""
self.df_lists = df_lists
def reset_state(self):
for d in self.df_lists:
d.reset_state()
def size(self):
return sum([x.size() for x in self.df_lists])
def get_data(self):
for d in self.df_lists:
for dp in d.get_data():
yield dp
......@@ -2,12 +2,17 @@
# File: format.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import h5py
from ..utils import logger
from .base import DataFlow
import random
from six.moves import range
from ..utils import logger
from .base import DataFlow
try:
import h5py
except ImportError:
logger.error("Error in `import h5py`. HDF5Data cannot function.")
"""
Adapter for different data format.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment