Commit 58c2779f authored by Yuxin Wu's avatar Yuxin Wu

rename and docs

parent 2a4d248f
......@@ -21,8 +21,8 @@ from ..utils.serialize import loads, dumps
from ..utils import logger
from ..utils.gpu import change_gpu
__all__ = ['PrefetchData', 'PrefetchDataZMQ', 'PrefetchOnGPUs',
'MultiThreadPrefetchData']
__all__ = ['PrefetchData', 'MultiProcessPrefetchData',
'PrefetchDataZMQ', 'PrefetchOnGPUs', 'MultiThreadPrefetchData']
def _repeat_iter(get_itr):
......@@ -112,7 +112,7 @@ class _MultiProcessZMQDataFlow(DataFlow):
pass
class PrefetchData(ProxyDataFlow):
class MultiProcessPrefetchData(ProxyDataFlow):
"""
Prefetch data from a DataFlow using Python multiprocessing utilities.
It will fork the process calling :meth:`__init__`, collect datapoints from `ds` in each
......@@ -135,7 +135,7 @@ class PrefetchData(ProxyDataFlow):
class _Worker(mp.Process):
def __init__(self, ds, queue):
super(PrefetchData._Worker, self).__init__()
super(MultiProcessPrefetchData._Worker, self).__init__()
self.ds = ds
self.queue = queue
......@@ -153,7 +153,7 @@ class PrefetchData(ProxyDataFlow):
nr_prefetch (int): size of the queue to hold prefetched datapoints.
nr_proc (int): number of processes to use.
"""
super(PrefetchData, self).__init__(ds)
super(MultiProcessPrefetchData, self).__init__(ds)
try:
self._size = ds.size()
except NotImplementedError:
......@@ -163,11 +163,11 @@ class PrefetchData(ProxyDataFlow):
self._guard = DataFlowReentrantGuard()
if nr_proc > 1:
logger.info("[PrefetchData] Will fork a dataflow more than one times. "
logger.info("[MultiProcessPrefetchData] Will fork a dataflow more than one times. "
"This assumes the datapoints are i.i.d.")
self.queue = mp.Queue(self.nr_prefetch)
self.procs = [PrefetchData._Worker(self.ds, self.queue)
self.procs = [MultiProcessPrefetchData._Worker(self.ds, self.queue)
for _ in range(self.nr_proc)]
ensure_proc_terminate(self.procs)
start_proc_mask_signal(self.procs)
......@@ -185,6 +185,9 @@ class PrefetchData(ProxyDataFlow):
pass
# Backward-compatible alias: this commit renames PrefetchData to
# MultiProcessPrefetchData (see class diff above in the original file);
# keep the old name importable so existing user code does not break.
PrefetchData = MultiProcessPrefetchData
class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
"""
Prefetch data from a DataFlow using multiple processes, with ZeroMQ for
......@@ -329,6 +332,7 @@ class MultiThreadPrefetchData(DataFlow):
def __init__(self, get_df, queue):
    """Initialize one prefetch worker.

    Args:
        get_df (callable): zero-argument factory returning a DataFlow.
            It is called here so each worker owns an independent
            DataFlow instance instead of sharing one.
        queue: presumably the shared queue this worker pushes prefetched
            datapoints into (its use is in ``run``, outside this view).
    """
    super(MultiThreadPrefetchData._Worker, self).__init__()
    self.df = get_df()
    # Fail fast if the factory returned something that is not a DataFlow;
    # the offending value is included in the assertion message.
    assert isinstance(self.df, DataFlow), self.df
    self.queue = queue
    # Daemon thread: must not keep the interpreter alive at shutdown.
    self.daemon = True
......
......@@ -294,6 +294,8 @@ MultiProcessMapData = MultiProcessMapDataZMQ # alias
def _pool_map(data):
global SHARED_ARR, WORKER_ID, MAP_FUNC
res = MAP_FUNC(data)
if res is None:
return None
shared = np.reshape(SHARED_ARR, res.shape)
assert shared.dtype == res.dtype
shared[:] = res
......@@ -303,8 +305,8 @@ def _pool_map(data):
class MultiProcessMapDataComponentSharedArray(DataFlow):
"""
Similar to :class:`MapDataComponent`, but perform IPC by shared memory,
therefore more efficient. It requires `map_func` to always return
a numpy array of fixed shape and dtype, or None.
therefore more efficient when data (result of map_func) is large.
It requires `map_func` to always return a numpy array of fixed shape and dtype, or None.
"""
def __init__(self, ds, nr_proc, map_func, output_shape, output_dtype, index=0):
"""
......@@ -370,6 +372,8 @@ class MultiProcessMapDataComponentSharedArray(DataFlow):
res = self._pool.map_async(_pool_map, to_map)
for index in res.get():
if index is None:
continue
arr = np.reshape(self._shared_mem[index], self.output_shape)
dp = dps[index]
dp[self.index] = arr.copy()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment