Commit 5dfebc8d authored by Yuxin Wu's avatar Yuxin Wu

mapdatacomponent with shared memory

parent ca06ba07
......@@ -6,6 +6,9 @@ from __future__ import print_function
import weakref
import threading
from contextlib import contextmanager
import numpy as np
import ctypes
import copy
import multiprocessing as mp
import itertools
from six.moves import range, zip, queue
......@@ -25,7 +28,8 @@ from ..utils.gpu import change_gpu
__all__ = ['PrefetchData', 'PrefetchDataZMQ', 'PrefetchOnGPUs',
'ThreadedMapData', 'MultiThreadMapData',
'MultiProcessMapData', 'MultiProcessMapDataZMQ']
'MultiProcessMapData', 'MultiProcessMapDataZMQ',
'MultiProcessMapDataComponentSharedArray']
def _repeat_iter(get_itr):
......@@ -589,6 +593,91 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
MultiProcessMapData = MultiProcessMapDataZMQ # alias
def _pool_map(data):
global SHARED_ARR, WORKER_ID, MAP_FUNC
res = MAP_FUNC(data)
shared = np.reshape(SHARED_ARR, res.shape)
assert shared.dtype == res.dtype
shared[:] = res
return WORKER_ID
class MultiProcessMapDataComponentSharedArray(DataFlow):
"""
Similar to :class:`MapDataComponent`, but perform IPC by shared memory,
therefore more efficient. It requires `map_func` to always return
a numpy array of fixed shape and dtype, or None.
"""
def __init__(self, ds, nr_proc, map_func, output_shape, output_dtype, index=0):
"""
Args:
ds (DataFlow): the dataflow to map on
nr_proc(int): number of processes
map_func (data component -> ndarray | None): the mapping function
output_shape (tuple): the shape of the output of map_func
output_dtype (np.dtype): the type of the output of map_func
index (int): the index of the datapoint component to map on.
"""
self.ds = ds
self.nr_proc = nr_proc
self.map_func = map_func
self.output_shape = output_shape
self.output_dtype = np.dtype(output_dtype).type
self.index = index
self._shared_mem = [self._create_shared_arr() for k in range(nr_proc)]
id_queue = mp.Queue()
for k in range(nr_proc):
id_queue.put(k)
def _init_pool(arrs, queue, map_func):
id = queue.get()
global SHARED_ARR, WORKER_ID, MAP_FUNC
SHARED_ARR = arrs[id]
WORKER_ID = id
MAP_FUNC = map_func
self._pool = mp.pool.Pool(
processes=nr_proc,
initializer=_init_pool,
initargs=(self._shared_mem, id_queue, map_func))
self._guard = DataFlowReentrantGuard()
def _create_shared_arr(self):
TYPE = {
np.float32: ctypes.c_float,
np.float64: ctypes.c_double,
np.uint8: ctypes.c_uint8,
np.int8: ctypes.c_int8,
np.int32: ctypes.c_int32,
}
ctype = TYPE[self.output_dtype]
arr = mp.RawArray(ctype, int(np.prod(self.output_shape)))
return arr
def size(self):
return self.ds.size()
def reset_state(self):
self.ds.reset_state()
def get_data(self):
ds_itr = _repeat_iter(self.ds.get_data)
with self._guard:
while True:
dps = []
for k in range(self.nr_proc):
dps.append(copy.copy(next(ds_itr)))
to_map = [x[self.index] for x in dps]
res = self._pool.map_async(_pool_map, to_map)
for index in res.get():
arr = np.reshape(self._shared_mem[index], self.output_shape)
dp = dps[index]
dp[self.index] = arr
yield dp
if __name__ == '__main__':
class Zero(DataFlow):
def __init__(self, size):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment