Commit 75925343 authored by Patrick Wieschollek's avatar Patrick Wieschollek Committed by Yuxin Wu

solution from issue 207 (#227)

* solution from issue 207

* remove copy

* remove old docs
parent 2fb8f750
......@@ -12,7 +12,7 @@ from .base import DataFlow, ProxyDataFlow, RNGDataFlow
from ..utils import logger, get_tqdm, get_rng
__all__ = ['TestDataSpeed', 'PrintData', 'BatchData', 'BatchDataByShape', 'FixedSizeData', 'MapData',
'MapDataComponent', 'RepeatedData', 'RandomChooseData',
'MapDataComponent', 'RepeatedData', 'RepeatedDataPoint', 'RandomChooseData',
'RandomMixData', 'JoinData', 'ConcatData', 'SelectComponent',
'LocallyShuffleData', 'CacheData']
......@@ -295,6 +295,31 @@ class RepeatedData(ProxyDataFlow):
yield dp
class RepeatedDataPoint(ProxyDataFlow):
""" Take data points from another DataFlow and produce them a
certain number of times dp1, ..., dp1, dp2, ..., dp2, ...
"""
def __init__(self, ds, nr):
"""
Args:
ds (DataFlow): input DataFlow
nr (int): number of times to repeat each datapoint.
"""
self.nr = int(nr)
assert self.nr >= 1, self.nr
super(RepeatedDataPoint, self).__init__(ds)
def size(self):
return self.ds.size() * self.nr
def get_data(self):
for dp in self.ds.get_data():
for _ in range(self.nr):
yield dp
class RandomChooseData(RNGDataFlow):
"""
Randomly choose from several DataFlow.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment