Commit 5dfebc8d authored by Yuxin Wu

mapdatacomponent with shared memory

parent ca06ba07
@@ -6,6 +6,9 @@ from __future__ import print_function
import weakref
import threading
from contextlib import contextmanager
import numpy as np
import ctypes
import copy
import multiprocessing as mp
import itertools
from six.moves import range, zip, queue
@@ -25,7 +28,8 @@ from ..utils.gpu import change_gpu
__all__ = ['PrefetchData', 'PrefetchDataZMQ', 'PrefetchOnGPUs',
           'ThreadedMapData', 'MultiThreadMapData',
           'MultiProcessMapData', 'MultiProcessMapDataZMQ',
           'MultiProcessMapDataComponentSharedArray']
def _repeat_iter(get_itr):
@@ -589,6 +593,91 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
MultiProcessMapData = MultiProcessMapDataZMQ  # alias

def _pool_map(data):
    global SHARED_ARR, WORKER_ID, MAP_FUNC
    res = MAP_FUNC(data)
    if res is None:
        return None
    # Write the result into this worker's shared array; only the small
    # integer worker id travels back through the pool's result pipe.
    shared = np.reshape(SHARED_ARR, res.shape)
    assert shared.dtype == res.dtype
    shared[:] = res
    return WORKER_ID
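
The write into `shared` works because reshaping a `multiprocessing.RawArray` through numpy yields a writable view of the shared buffer rather than a copy, so the parent process sees the result without pickling the (potentially large) array. A minimal standalone sketch of that mechanism, with illustrative names that are not part of the patch:

import ctypes
import multiprocessing as mp
import numpy as np

arr = mp.RawArray(ctypes.c_float, 6)   # lock-free shared buffer of 6 floats
view = np.reshape(arr, (2, 3))         # writable numpy view, no copy
view[:] = np.arange(6, dtype=np.float32).reshape(2, 3)
print(arr[:])                          # [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]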

class MultiProcessMapDataComponentSharedArray(DataFlow):
    """
    Similar to :class:`MapDataComponent`, but performs IPC through shared
    memory and is therefore more efficient. It requires `map_func` to always
    return a numpy array of fixed shape and dtype, or None.
    """
    def __init__(self, ds, nr_proc, map_func, output_shape, output_dtype, index=0):
        """
        Args:
            ds (DataFlow): the dataflow to map on
            nr_proc (int): number of processes
            map_func (data component -> ndarray | None): the mapping function
            output_shape (tuple): the shape of the output of map_func
            output_dtype (np.dtype): the type of the output of map_func
            index (int): the index of the datapoint component to map on.
        """
        self.ds = ds
        self.nr_proc = nr_proc
        self.map_func = map_func
        self.output_shape = output_shape
        self.output_dtype = np.dtype(output_dtype).type
        self.index = index

        # One shared array per worker, plus a queue that hands each worker
        # a unique id so it knows which array to write into.
        self._shared_mem = [self._create_shared_arr() for k in range(nr_proc)]
        id_queue = mp.Queue()
        for k in range(nr_proc):
            id_queue.put(k)

        def _init_pool(arrs, queue, map_func):
            id = queue.get()
            global SHARED_ARR, WORKER_ID, MAP_FUNC
            SHARED_ARR = arrs[id]
            WORKER_ID = id
            MAP_FUNC = map_func

        self._pool = mp.pool.Pool(
            processes=nr_proc,
            initializer=_init_pool,
            initargs=(self._shared_mem, id_queue, map_func))

        self._guard = DataFlowReentrantGuard()
    def _create_shared_arr(self):
        TYPE = {
            np.float32: ctypes.c_float,
            np.float64: ctypes.c_double,
            np.uint8: ctypes.c_uint8,
            np.int8: ctypes.c_int8,
            np.int32: ctypes.c_int32,
        }
        ctype = TYPE[self.output_dtype]
        # A raw (lock-free) shared buffer large enough for one output array.
        arr = mp.RawArray(ctype, int(np.prod(self.output_shape)))
        return arr
    def size(self):
        return self.ds.size()

    def reset_state(self):
        self.ds.reset_state()
    def get_data(self):
        ds_itr = _repeat_iter(self.ds.get_data)
        with self._guard:
            while True:
                # Collect one datapoint per worker, then map the chosen
                # component of each in parallel.
                dps = []
                for k in range(self.nr_proc):
                    dps.append(copy.copy(next(ds_itr)))
                to_map = [x[self.index] for x in dps]
                res = self._pool.map_async(_pool_map, to_map)

                for index in res.get():
                    if index is None:
                        continue
                    # `arr` is a view into shared memory; it stays valid
                    # only until the next batch is mapped.
                    arr = np.reshape(self._shared_mem[index], self.output_shape)
                    dp = dps[index]
                    dp[self.index] = arr
                    yield dp
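
The `_init_pool` handshake in `__init__` is the key trick: a pool initializer runs exactly once in every worker process, and draining a queue of ids inside it gives each worker a distinct identity, hence its own shared array. A minimal standalone sketch of that pattern (names here are illustrative, not from the patch):

import multiprocessing as mp

def _init_worker(id_queue):
    # Runs once per worker process; the queue hands out unique ids.
    global WORKER_ID
    WORKER_ID = id_queue.get()

def _whoami(_):
    return WORKER_ID

if __name__ == '__main__':
    nr_proc = 4
    id_queue = mp.Queue()
    for k in range(nr_proc):
        id_queue.put(k)
    pool = mp.Pool(processes=nr_proc,
                   initializer=_init_worker, initargs=(id_queue,))
    # e.g. [0, 1, 2, 3]; a fast worker may show up more than once
    print(sorted(pool.map(_whoami, range(nr_proc))))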
if __name__ == '__main__':
    class Zero(DataFlow):
        def __init__(self, size):
            ...
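
For completeness, a hedged usage sketch of the new class; `MyDataFlow`, the shapes, and the resize function are illustrative assumptions, not part of the patch:

import cv2  # assumed dependency, used only for this example's map function

def resize(img):
    # Must always return a fixed shape and dtype, as the class requires.
    return cv2.resize(img, (224, 224)).astype(np.float32)

df = MyDataFlow()  # hypothetical dataflow yielding datapoints [img, label]
df = MultiProcessMapDataComponentSharedArray(
    df, nr_proc=4, map_func=resize,
    output_shape=(224, 224, 3), output_dtype=np.float32, index=0)
df.reset_state()
for dp in df.get_data():
    pass  # dp[0] is now a float32 (224, 224, 3) array read from shared memory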