Commit 4d64098a authored by Yuxin Wu's avatar Yuxin Wu

selectcomponent

parent 27ea2836
...@@ -65,7 +65,9 @@ class BatchData(ProxyDataFlow): ...@@ -65,7 +65,9 @@ class BatchData(ProxyDataFlow):
return result return result
class FixedSizeData(ProxyDataFlow): class FixedSizeData(ProxyDataFlow):
""" Generate data from another DataFlow, but with a fixed epoch size""" """ Generate data from another DataFlow, but with a fixed epoch size.
The state of the underlying DataFlow is maintained among each epoch.
"""
def __init__(self, ds, size): def __init__(self, ds, size):
""" """
:param ds: a :mod:`DataFlow` to produce data :param ds: a :mod:`DataFlow` to produce data
...@@ -165,7 +167,7 @@ class MapDataComponent(ProxyDataFlow): ...@@ -165,7 +167,7 @@ class MapDataComponent(ProxyDataFlow):
def __init__(self, ds, func, index=0): def __init__(self, ds, func, index=0):
""" """
:param ds: a :mod:`DataFlow` instance. :param ds: a :mod:`DataFlow` instance.
:param func: a function that takes a datapoint dp[index], returns a :param func: a function that takes a datapoint component dp[index], returns a
new value of dp[index]. return None to skip this datapoint. new value of dp[index]. return None to skip this datapoint.
""" """
super(MapDataComponent, self).__init__(ds) super(MapDataComponent, self).__init__(ds)
...@@ -269,3 +271,21 @@ class JoinData(DataFlow): ...@@ -269,3 +271,21 @@ class JoinData(DataFlow):
for dp in d.get_data(): for dp in d.get_data():
yield dp yield dp
class SelectComponent(ProxyDataFlow):
"""
Select component from a datapoint.
"""
def __init__(self, ds, idxs):
"""
:param ds: a :mod:`DataFlow` instance
:param idxs: a list of datapoint component index of the original dataflow
"""
super(SelectComponent, self).__init__(ds)
self.idxs = idxs
def get_data(self):
for dp in self.ds.get_data():
newdp = []
for idx in self.idxs:
newdp.append(dp[idx])
yield newdp
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import multiprocessing import multiprocessing
from six.moves import range
from .base import ProxyDataFlow from .base import ProxyDataFlow
from ..utils.concurrency import ensure_procs_terminate from ..utils.concurrency import ensure_procs_terminate
...@@ -49,12 +50,9 @@ class PrefetchData(ProxyDataFlow): ...@@ -49,12 +50,9 @@ class PrefetchData(ProxyDataFlow):
def get_data(self): def get_data(self):
tot_cnt = 0 tot_cnt = 0
while True: for _ in range(tot_cnt):
dp = self.queue.get() dp = self.queue.get()
yield dp yield dp
tot_cnt += 1
if tot_cnt == self._size:
break
def __del__(self): def __del__(self):
self.queue.close() self.queue.close()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment