Commit 58c2779f authored by Yuxin Wu

rename and docs

parent 2a4d248f
@@ -21,8 +21,8 @@ from ..utils.serialize import loads, dumps
 from ..utils import logger
 from ..utils.gpu import change_gpu

-__all__ = ['PrefetchData', 'PrefetchDataZMQ', 'PrefetchOnGPUs',
-           'MultiThreadPrefetchData']
+__all__ = ['PrefetchData', 'MultiProcessPrefetchData',
+           'PrefetchDataZMQ', 'PrefetchOnGPUs', 'MultiThreadPrefetchData']


 def _repeat_iter(get_itr):
@@ -112,7 +112,7 @@ class _MultiProcessZMQDataFlow(DataFlow):
             pass


-class PrefetchData(ProxyDataFlow):
+class MultiProcessPrefetchData(ProxyDataFlow):
     """
     Prefetch data from a DataFlow using Python multiprocessing utilities.
     It will fork the process calling :meth:`__init__`, collect datapoints from `ds` in each
@@ -135,7 +135,7 @@ class PrefetchData(ProxyDataFlow):
     class _Worker(mp.Process):
         def __init__(self, ds, queue):
-            super(PrefetchData._Worker, self).__init__()
+            super(MultiProcessPrefetchData._Worker, self).__init__()
             self.ds = ds
             self.queue = queue
@@ -153,7 +153,7 @@ class PrefetchData(ProxyDataFlow):
             nr_prefetch (int): size of the queue to hold prefetched datapoints.
             nr_proc (int): number of processes to use.
         """
-        super(PrefetchData, self).__init__(ds)
+        super(MultiProcessPrefetchData, self).__init__(ds)
         try:
             self._size = ds.size()
         except NotImplementedError:
@@ -163,11 +163,11 @@ class PrefetchData(ProxyDataFlow):
         self._guard = DataFlowReentrantGuard()

         if nr_proc > 1:
-            logger.info("[PrefetchData] Will fork a dataflow more than once. "
+            logger.info("[MultiProcessPrefetchData] Will fork a dataflow more than once. "
                         "This assumes the datapoints are i.i.d.")

         self.queue = mp.Queue(self.nr_prefetch)
-        self.procs = [PrefetchData._Worker(self.ds, self.queue)
+        self.procs = [MultiProcessPrefetchData._Worker(self.ds, self.queue)
                       for _ in range(self.nr_proc)]
         ensure_proc_terminate(self.procs)
         start_proc_mask_signal(self.procs)
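Note: for context, a minimal usage sketch of the renamed class, assuming a hypothetical user-defined `MyDataFlow`:

    from tensorpack.dataflow import MultiProcessPrefetchData

    ds = MyDataFlow()
    # Fork 4 worker processes that together fill a queue of up to 256
    # prefetched datapoints. Because the dataflow is forked more than once,
    # every worker yields its own copy of the stream; as the log message
    # above says, this is only safe when datapoints are i.i.d.
    ds = MultiProcessPrefetchData(ds, nr_prefetch=256, nr_proc=4)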
@@ -185,6 +185,9 @@ class PrefetchData(ProxyDataFlow):
             pass


+PrefetchData = MultiProcessPrefetchData
+
+
 class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
     """
     Prefetch data from a DataFlow using multiple processes, with ZeroMQ for
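Note: the module-level assignment above keeps the old name importable, so existing user code is unaffected by the rename:

    from tensorpack.dataflow import PrefetchData, MultiProcessPrefetchData

    # The old name is now a plain alias for the renamed class:
    assert PrefetchData is MultiProcessPrefetchData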
@@ -329,6 +332,7 @@ class MultiThreadPrefetchData(DataFlow):
         def __init__(self, get_df, queue):
             super(MultiThreadPrefetchData._Worker, self).__init__()
             self.df = get_df()
+            assert isinstance(self.df, DataFlow), self.df
             self.queue = queue
             self.daemon = True
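Note: the added assert fails fast when `get_df` is not a proper factory. A sketch of the expected call, with `MyDataFlow` again hypothetical and keyword names taken from this version's signature as an assumption:

    # MultiThreadPrefetchData expects a zero-argument callable that builds a
    # fresh DataFlow for each worker thread; the new assert verifies that the
    # callable actually returned a DataFlow instance.
    ds = MultiThreadPrefetchData(
        get_df=lambda: MyDataFlow(), nr_prefetch=64, nr_thread=4)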
...
@@ -294,6 +294,8 @@ MultiProcessMapData = MultiProcessMapDataZMQ # alias
 def _pool_map(data):
     global SHARED_ARR, WORKER_ID, MAP_FUNC
     res = MAP_FUNC(data)
+    if res is None:
+        return None
     shared = np.reshape(SHARED_ARR, res.shape)
     assert shared.dtype == res.dtype
     shared[:] = res
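Note: with this change a `map_func` can drop datapoints by returning None: the worker skips the shared-memory write, and the parent (see the matching hunk below) skips the result. A hypothetical filtering `map_func`:

    import cv2

    def map_func(img):
        # Drop datapoints whose image component is too small ...
        if min(img.shape[:2]) < 32:
            return None
        # ... otherwise return an array of the fixed shape/dtype the class
        # was constructed with, e.g. (224, 224, 3) float32.
        return cv2.resize(img, (224, 224)).astype('float32')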
@@ -303,8 +305,8 @@ def _pool_map(data):
 class MultiProcessMapDataComponentSharedArray(DataFlow):
     """
     Similar to :class:`MapDataComponent`, but performs IPC by shared memory,
-    therefore more efficient. It requires `map_func` to always return
-    a numpy array of fixed shape and dtype, or None.
+    therefore more efficient when data (result of map_func) is large.
+    It requires `map_func` to always return a numpy array of fixed shape and dtype, or None.
     """
     def __init__(self, ds, nr_proc, map_func, output_shape, output_dtype, index=0):
         """
@@ -370,6 +372,8 @@ class MultiProcessMapDataComponentSharedArray(DataFlow):
         res = self._pool.map_async(_pool_map, to_map)

         for index in res.get():
+            if index is None:
+                continue
             arr = np.reshape(self._shared_mem[index], self.output_shape)
             dp = dps[index]
             dp[self.index] = arr.copy()
...