Commit 8c9c61d3 authored by Yuxin Wu's avatar Yuxin Wu

joindata

parent fe548cf0
...@@ -9,7 +9,7 @@ from .base import DataFlow, ProxyDataFlow ...@@ -9,7 +9,7 @@ from .base import DataFlow, ProxyDataFlow
from ..utils import * from ..utils import *
__all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData', __all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData',
'MapDataComponent', 'RandomChooseData', 'RandomMixData'] 'MapDataComponent', 'RandomChooseData', 'RandomMixData', 'JoinData']
class BatchData(ProxyDataFlow): class BatchData(ProxyDataFlow):
def __init__(self, ds, batch_size, remainder=False): def __init__(self, ds, batch_size, remainder=False):
...@@ -218,3 +218,25 @@ class RandomMixData(DataFlow): ...@@ -218,3 +218,25 @@ class RandomMixData(DataFlow):
for k in idxs: for k in idxs:
yield next(itrs[k]) yield next(itrs[k])
class JoinData(DataFlow):
"""
Concatenate several dataflows
"""
def __init__(self, df_lists):
"""
df_lists: list of dataflow
"""
self.df_lists = df_lists
def reset_state(self):
for d in self.df_lists:
d.reset_state()
def size(self):
return sum([x.size() for x in self.df_lists])
def get_data(self):
for d in self.df_lists:
for dp in d.get_data():
yield dp
...@@ -2,12 +2,17 @@ ...@@ -2,12 +2,17 @@
# File: format.py # File: format.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import h5py from ..utils import logger
from .base import DataFlow
import random import random
from six.moves import range from six.moves import range
from ..utils import logger try:
from .base import DataFlow import h5py
except ImportError:
logger.error("Error in `import h5py`. HDF5Data cannot function.")
""" """
Adapter for different data format. Adapter for different data format.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment