Commit 68e8d9eb authored by Yuxin Wu

MultiProcessMapData with strict (#414)

parent be3a07a1
Bug Reports/Feature Requests/Usage Questions Only:
Bug Reports: PLEASE always include
Bug reports or other problems with code: PLEASE always include
1. What you did. (command you run and changes you made if using examples; post or describe your code if not)
2. What you observed, e.g. logs.
2. What you observed, e.g. as much logs as possible.
3. What you expected, if not obvious.
4. Your environment (TF version, cudnn version, number & type of GPUs), if it matters.
5. About efficiency, PLEASE first read http://tensorpack.readthedocs.io/en/latest/tutorial/performance-tuning.html
......@@ -14,10 +14,4 @@ Feature Requests:
It may not have to be added to tensorpack unless you have a good reason.
3. Note that we don't implement papers at others' requests.
Usage Questions, e.g.:
"How do I do [this specific thing] in tensorpack?"
"Why certain examples need to be written in this way?"
We don't answer general machine learning questions like:
"I want to do [this machine learning task]. What specific things do I need to do?"
You can also use gitter (https://gitter.im/tensorpack/users) for more casual discussions.
......@@ -67,20 +67,11 @@ def _zmq_catch_error(name):
class _MultiProcessZMQDataFlow(DataFlow):
def __init__(self, ds):
def __init__(self):
assert os.name != 'nt', "ZMQ IPC is not supported on Windows!"
self._reset_done = False
self._procs = []
self.ds = ds
try:
self._size = ds.size()
except NotImplementedError:
self._size = -1
def size(self):
return self.ds.size()
def reset_state(self):
"""
All forked dataflows are reset **once and only once** in spawned processes.
......@@ -265,10 +256,17 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
if nr_proc > 1:
logger.info("[PrefetchDataZMQ] Will fork a dataflow more than one times. "
"This assumes the datapoints are i.i.d.")
try:
self._size = ds.size()
except NotImplementedError:
self._size = -1
def _recv(self):
return loads(self.socket.recv(copy=False).bytes)
def size(self):
return self.ds.size()
def get_data(self):
with self._guard, _zmq_catch_error('PrefetchDataZMQ'):
for k in itertools.count():
......@@ -311,7 +309,59 @@ class PrefetchOnGPUs(PrefetchDataZMQ):
proc.start()
class MultiThreadMapData(ProxyDataFlow):
class _ParallelMapData(ProxyDataFlow):
def __init__(self, ds, buffer_size):
super(_ParallelMapData, self).__init__(ds)
assert buffer_size > 0, buffer_size
self._buffer_size = buffer_size
def _recv(self):
pass
def _send(self, dp):
pass
def _recv_filter_none(self):
ret = self._recv()
assert ret is not None, \
"[{}] Map function cannot return None when strict mode is used.".format(type(self).__name__)
return ret
def _fill_buffer(self):
try:
for _ in range(self._buffer_size):
dp = next(self._iter)
self._send(dp)
except StopIteration:
logger.error(
"[{}] buffer_size cannot be larger than the size of the DataFlow!".format(type(self).__name__))
raise
def get_data_non_strict(self):
for dp in self._iter:
self._send(dp)
yield self._recv()
self._iter = self.ds.get_data() # refresh
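# keep the buffer full across the pass boundary: send buffer_size datapoints from the refreshed iterator while yielding whatever results come back (results from the two passes may mix)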
for _ in range(self._buffer_size):
self._send(next(self._iter))
yield self._recv()
def get_data_strict(self):
for dp in self._iter:
self._send(dp)
yield self._recv_filter_none()
self._iter = self.ds.get_data() # refresh
# first clear the buffer, then fill
for k in range(self._buffer_size):
dp = self._recv_filter_none()
if k == self._buffer_size - 1:
self._fill_buffer()
yield dp
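# For illustration only (not part of this commit): a minimal single-process
# subclass sketching the _send/_recv contract that _ParallelMapData relies on.
# The real subclasses below push the datapoint to worker threads
# (MultiThreadMapData) or to ZMQ worker processes (MultiProcessMapDataZMQ)
# instead of a local deque. The class name is hypothetical.
import collections

class _SerialMapDataExample(_ParallelMapData):
    def __init__(self, ds, map_func, buffer_size=10):
        super(_SerialMapDataExample, self).__init__(ds, buffer_size)
        self.map_func = map_func
        self._queue = collections.deque()

    def _send(self, dp):
        # "submit work": enqueue an unmapped datapoint
        self._queue.append(dp)

    def _recv(self):
        # "collect a result": dequeue the oldest datapoint and map it
        return self.map_func(self._queue.popleft())

    def reset_state(self):
        super(_SerialMapDataExample, self).reset_state()
        self._iter = self.ds.get_data()
        self._fill_buffer()

    def get_data(self):
        for dp in self.get_data_non_strict():
            yield dp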
class MultiThreadMapData(_ParallelMapData):
"""
Same as :class:`MapData`, but start threads to run the mapping function.
This is useful when the mapping function is the bottleneck, but you don't
......@@ -367,11 +417,10 @@ class MultiThreadMapData(ProxyDataFlow):
buffer_size (int): number of datapoints in the buffer
strict (bool): use "strict mode", see notes above.
"""
super(MultiThreadMapData, self).__init__(ds)
super(MultiThreadMapData, self).__init__(ds, buffer_size)
self._strict = strict
self.nr_thread = nr_thread
self.buffer_size = buffer_size
self.map_func = map_func
self._threads = []
self._evt = None
......@@ -398,43 +447,20 @@ class MultiThreadMapData(ProxyDataFlow):
# only call once, to ensure inq+outq has a total of buffer_size elements
self._fill_buffer()
def _fill_buffer(self):
n = self.buffer_size - self._in_queue.qsize() - self._out_queue.qsize()
assert n >= 0, n
if n == 0:
return
try:
for _ in range(n):
self._in_queue.put(next(self._iter))
except StopIteration:
logger.error("[MultiThreadMapData] buffer_size cannot be larger than the size of the DataFlow!")
raise
def _recv(self):
ret = self._out_queue.get()
if ret is None:
assert not self._strict, \
"[MultiThreadMapData] Map function cannot return None when strict mode is used."
return ret
return self._out_queue.get()
def get_data(self):
with self._guard:
for dp in self._iter:
def _send(self, dp):
self._in_queue.put(dp)
yield self._recv()
self._iter = self.ds.get_data()
def get_data(self):
with self._guard:
if self._strict:
# first call get() to clear the queues, then fill
for k in range(self.buffer_size):
dp = self._recv()
if k == self.buffer_size - 1:
self._fill_buffer()
for dp in self.get_data_strict():
yield dp
else:
for _ in range(self.buffer_size):
self._in_queue.put(next(self._iter))
yield self._recv()
for dp in self.get_data_non_strict():
yield dp
def __del__(self):
if self._evt is not None:
......@@ -447,10 +473,20 @@ class MultiThreadMapData(ProxyDataFlow):
ThreadedMapData = MultiThreadMapData
class MultiProcessMapDataZMQ(_MultiProcessZMQDataFlow):
class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
"""
Same as :class:`MapData`, but start processes to run the mapping function,
and communicate with ZeroMQ pipe.
Note:
1. Processes run in parallel and can take different time to run the
mapping function. Therefore the order of datapoints won't be
preserved, and datapoints from one pass of `df.get_data()` might get
mixed with datapoints from the next pass.
You can use **strict mode**, where `MultiProcessMapData.get_data()`
is guaranteed to produce the exact set of datapoints which `df.get_data()`
produces, although the order of the data still isn't preserved.
"""
class _Worker(mp.Process):
def __init__(self, identity, map_func, pipename, hwm):
......@@ -472,30 +508,32 @@ class MultiProcessMapDataZMQ(_MultiProcessZMQDataFlow):
dp = self.map_func(dp)
socket.send(dumps(dp), copy=False)
def __init__(self, ds, nr_proc, map_func, buffer_size=200):
def __init__(self, ds, nr_proc, map_func, buffer_size=200, strict=False):
"""
Args:
ds (DataFlow): the dataflow to map
nr_proc (int): number of processes to use
map_func (callable): datapoint -> datapoint | None
buffer_size (int): number of datapoints in the buffer
strict (bool): use "strict mode", see notes above.
"""
super(MultiProcessMapDataZMQ, self).__init__(ds)
_ParallelMapData.__init__(self, ds, buffer_size)
_MultiProcessZMQDataFlow.__init__(self)
self.nr_proc = nr_proc
self.map_func = map_func
self.buffer_size = buffer_size
self._strict = strict
self._procs = []
self._guard = DataFlowReentrantGuard()
def _reset_once(self):
self.context = zmq.Context()
self.socket = self.context.socket(zmq.ROUTER)
self.socket.set_hwm(self.buffer_size * 2)
self.socket.set_hwm(self._buffer_size * 2)
pipename = _get_pipe_name('dataflow-map')
_bind_guard(self.socket, pipename)
self._proc_ids = [u'{}'.format(k).encode('utf-8') for k in range(self.nr_proc)]
worker_hwm = int(self.buffer_size * 2 // self.nr_proc)
worker_hwm = int(self._buffer_size * 2 // self.nr_proc)
self._procs = [MultiProcessMapDataZMQ._Worker(
self._proc_ids[k], self.map_func, pipename, worker_hwm)
for k in range(self.nr_proc)]
......@@ -507,14 +545,8 @@ class MultiProcessMapDataZMQ(_MultiProcessZMQDataFlow):
self._start_processes()
self._fill_buffer()
def _fill_buffer(self):
# Filling the buffer.
try:
for _ in range(self.buffer_size):
self._send(next(self._iter))
except StopIteration:
logger.error("[MultiProcessMapData] buffer_size cannot be larger than the size of the DataFlow!")
raise
def reset_state(self):
_MultiProcessZMQDataFlow.reset_state(self)
def _send(self, dp):
# round-robin assignment
......@@ -529,14 +561,12 @@ class MultiProcessMapDataZMQ(_MultiProcessZMQDataFlow):
def get_data(self):
with self._guard, _zmq_catch_error('MultiProcessMapData'):
for dp in self._iter:
self._send(dp)
yield self._recv()
self._iter = self.ds.get_data() # refresh
for _ in range(self.buffer_size):
self._send(next(self._iter))
yield self._recv()
if self._strict:
for dp in self.get_data_strict():
yield dp
else:
for dp in self.get_data_non_strict():
yield dp
MultiProcessMapData = MultiProcessMapDataZMQ # alias
......@@ -549,13 +579,13 @@ if __name__ == '__main__':
def get_data(self):
for k in range(self._size):
yield [0]
yield [k]
def size(self):
return self._size
ds = Zero(300)
ds = MultiProcessMapData(ds, 3, lambda x: [x[0] + 1])
ds = MultiProcessMapData(ds, 3, lambda x: [x[0] + 1], strict=True)
ds.reset_state()
for k in ds.get_data():
print("Bang!", k)