Commit 68e8d9eb authored by Yuxin Wu's avatar Yuxin Wu

MultiProcessMapData with strict (#414)

parent be3a07a1
Bug Reports/Feature Requests/Usage Questions Only: Bug Reports/Feature Requests/Usage Questions Only:
Bug Reports: PLEASE always include Bug reports or other problems with code: PLEASE always include
1. What you did. (command you run and changes you made if using examples; post or describe your code if not) 1. What you did. (command you run and changes you made if using examples; post or describe your code if not)
2. What you observed, e.g. logs. 2. What you observed, e.g. as many logs as possible.
3. What you expected, if not obvious. 3. What you expected, if not obvious.
4. Your environment (TF version, cudnn version, number & type of GPUs), if it matters. 4. Your environment (TF version, cudnn version, number & type of GPUs), if it matters.
5. About efficiency, PLEASE first read http://tensorpack.readthedocs.io/en/latest/tutorial/performance-tuning.html 5. About efficiency, PLEASE first read http://tensorpack.readthedocs.io/en/latest/tutorial/performance-tuning.html
...@@ -14,10 +14,4 @@ Feature Requests: ...@@ -14,10 +14,4 @@ Feature Requests:
It may not have to be added to tensorpack unless you have a good reason. It may not have to be added to tensorpack unless you have a good reason.
3. Note that we don't implement papers at others' requests. 3. Note that we don't implement papers at others' requests.
Usage Questions, e.g.:
"How do I do [this specific thing] in tensorpack?"
"Why certain examples need to be written in this way?"
We don't answer general machine learning questions like:
"I want to do [this machine learning task]. What specific things do I need to do?"
You can also use gitter (https://gitter.im/tensorpack/users) for more casual discussions. You can also use gitter (https://gitter.im/tensorpack/users) for more casual discussions.
...@@ -67,20 +67,11 @@ def _zmq_catch_error(name): ...@@ -67,20 +67,11 @@ def _zmq_catch_error(name):
class _MultiProcessZMQDataFlow(DataFlow): class _MultiProcessZMQDataFlow(DataFlow):
def __init__(self, ds): def __init__(self):
assert os.name != 'nt', "ZMQ IPC doesn't support windows!" assert os.name != 'nt', "ZMQ IPC doesn't support windows!"
self._reset_done = False self._reset_done = False
self._procs = [] self._procs = []
self.ds = ds
try:
self._size = ds.size()
except NotImplementedError:
self._size = -1
def size(self):
return self.ds.size()
def reset_state(self): def reset_state(self):
""" """
All forked dataflows are reset **once and only once** in spawned processes. All forked dataflows are reset **once and only once** in spawned processes.
...@@ -265,10 +256,17 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow): ...@@ -265,10 +256,17 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
if nr_proc > 1: if nr_proc > 1:
logger.info("[PrefetchDataZMQ] Will fork a dataflow more than one times. " logger.info("[PrefetchDataZMQ] Will fork a dataflow more than one times. "
"This assumes the datapoints are i.i.d.") "This assumes the datapoints are i.i.d.")
try:
self._size = ds.size()
except NotImplementedError:
self._size = -1
def _recv(self): def _recv(self):
return loads(self.socket.recv(copy=False).bytes) return loads(self.socket.recv(copy=False).bytes)
def size(self):
    # One pass of the parallel wrapper yields exactly as many datapoints as
    # the wrapped dataflow, so delegate. May raise NotImplementedError for
    # dataflows without a known size (see the try/except in __init__).
    return self.ds.size()
def get_data(self): def get_data(self):
with self._guard, _zmq_catch_error('PrefetchDataZMQ'): with self._guard, _zmq_catch_error('PrefetchDataZMQ'):
for k in itertools.count(): for k in itertools.count():
...@@ -311,7 +309,59 @@ class PrefetchOnGPUs(PrefetchDataZMQ): ...@@ -311,7 +309,59 @@ class PrefetchOnGPUs(PrefetchDataZMQ):
proc.start() proc.start()
class MultiThreadMapData(ProxyDataFlow): class _ParallelMapData(ProxyDataFlow):
def __init__(self, ds, buffer_size):
    """Base for parallel map dataflows that keep a buffer of in-flight datapoints.

    Args:
        ds (DataFlow): the dataflow whose datapoints will be mapped.
        buffer_size (int): number of datapoints kept in flight between
            ``_send`` and ``_recv``; must be positive.
    """
    super(_ParallelMapData, self).__init__(ds)
    assert buffer_size > 0, buffer_size
    self._buffer_size = buffer_size
def _recv(self):
    """Receive one mapped datapoint from the worker side.

    Placeholder with no effect; subclasses override it (returns None here,
    just like the original ``pass`` body).
    """
def _send(self, dp):
    """Hand one datapoint ``dp`` to the worker side for mapping.

    Placeholder with no effect; subclasses override it (returns None here,
    just like the original ``pass`` body).
    """
def _recv_filter_none(self):
    """Receive one mapped datapoint and verify it is not None.

    Returns:
        the received datapoint.

    Raises:
        AssertionError: if the map function produced None while strict
            mode is in use.
    """
    ret = self._recv()
    if ret is None:
        # Raise explicitly instead of using `assert`, so the strict-mode
        # guarantee still holds under `python -O`, which strips asserts.
        # AssertionError is kept so any existing callers' handling is unchanged.
        raise AssertionError(
            "[{}] Map function cannot return None when strict mode is used.".format(type(self).__name__))
    return ret
def _fill_buffer(self):
    """Pre-load the pipeline with ``self._buffer_size`` datapoints drawn
    from the underlying iterator, so ``_recv`` has results pending.

    Raises StopIteration (after logging) if the dataflow is smaller than
    the buffer.
    """
    remaining = self._buffer_size
    try:
        while remaining > 0:
            self._send(next(self._iter))
            remaining -= 1
    except StopIteration:
        logger.error(
            "[{}] buffer_size cannot be larger than the size of the DataFlow!".format(type(self).__name__))
        raise
def get_data_non_strict(self):
    """Yield mapped datapoints without the strict one-pass guarantee.

    Steady state sends one datapoint for every one received, keeping the
    buffer full. When the iterator ends, the buffered tail of this pass is
    drained while datapoints from the refreshed iterator are already being
    sent — so results from consecutive passes may interleave.
    """
    for dp in self._iter:
        self._send(dp)
        yield self._recv()

    self._iter = self.ds.get_data()  # refresh
    for _ in range(self._buffer_size):
        self._send(next(self._iter))
        yield self._recv()
def get_data_strict(self):
    """Yield mapped datapoints with the strict guarantee: one pass of this
    generator produces exactly the datapoints of one pass of the wrapped
    dataflow (order not preserved), and None results are rejected.
    """
    for dp in self._iter:
        self._send(dp)
        yield self._recv_filter_none()

    self._iter = self.ds.get_data()  # refresh
    # first clear the buffer, then fill
    for k in range(self._buffer_size):
        dp = self._recv_filter_none()
        if k == self._buffer_size - 1:
            # Buffer is about to be empty: refill it from the fresh
            # iterator before yielding the last buffered datapoint, so the
            # next pass starts with a full buffer.
            self._fill_buffer()
        yield dp
class MultiThreadMapData(_ParallelMapData):
""" """
Same as :class:`MapData`, but start threads to run the mapping function. Same as :class:`MapData`, but start threads to run the mapping function.
This is useful when the mapping function is the bottleneck, but you don't This is useful when the mapping function is the bottleneck, but you don't
...@@ -367,11 +417,10 @@ class MultiThreadMapData(ProxyDataFlow): ...@@ -367,11 +417,10 @@ class MultiThreadMapData(ProxyDataFlow):
buffer_size (int): number of datapoints in the buffer buffer_size (int): number of datapoints in the buffer
strict (bool): use "strict mode", see notes above. strict (bool): use "strict mode", see notes above.
""" """
super(MultiThreadMapData, self).__init__(ds) super(MultiThreadMapData, self).__init__(ds, buffer_size)
self._strict = strict self._strict = strict
self.nr_thread = nr_thread self.nr_thread = nr_thread
self.buffer_size = buffer_size
self.map_func = map_func self.map_func = map_func
self._threads = [] self._threads = []
self._evt = None self._evt = None
...@@ -398,43 +447,20 @@ class MultiThreadMapData(ProxyDataFlow): ...@@ -398,43 +447,20 @@ class MultiThreadMapData(ProxyDataFlow):
# only call once, to ensure inq+outq has a total of buffer_size elements # only call once, to ensure inq+outq has a total of buffer_size elements
self._fill_buffer() self._fill_buffer()
def _fill_buffer(self):
    """Top up the input queue so that the input and output queues together
    again hold ``self.buffer_size`` in-flight datapoints.

    Raises StopIteration (after logging) if the dataflow is smaller than
    the buffer.
    """
    shortfall = self.buffer_size - self._in_queue.qsize() - self._out_queue.qsize()
    assert shortfall >= 0, shortfall
    if shortfall == 0:
        return
    try:
        for _ in range(shortfall):
            self._in_queue.put(next(self._iter))
    except StopIteration:
        logger.error("[MultiThreadMapData] buffer_size cannot be larger than the size of the DataFlow!")
        raise
def _recv(self): def _recv(self):
ret = self._out_queue.get() return self._out_queue.get()
if ret is None:
assert not self._strict, \ def _send(self, dp):
"[MultiThreadMapData] Map function cannot return None when strict mode is used." self._in_queue.put(dp)
return ret
def get_data(self): def get_data(self):
with self._guard: with self._guard:
for dp in self._iter:
self._in_queue.put(dp)
yield self._recv()
self._iter = self.ds.get_data()
if self._strict: if self._strict:
# first call get() to clear the queues, then fill for dp in self.get_data_strict():
for k in range(self.buffer_size):
dp = self._recv()
if k == self.buffer_size - 1:
self._fill_buffer()
yield dp yield dp
else: else:
for _ in range(self.buffer_size): for dp in self.get_data_non_strict():
self._in_queue.put(next(self._iter)) yield dp
yield self._recv()
def __del__(self): def __del__(self):
if self._evt is not None: if self._evt is not None:
...@@ -447,10 +473,20 @@ class MultiThreadMapData(ProxyDataFlow): ...@@ -447,10 +473,20 @@ class MultiThreadMapData(ProxyDataFlow):
ThreadedMapData = MultiThreadMapData ThreadedMapData = MultiThreadMapData
class MultiProcessMapDataZMQ(_MultiProcessZMQDataFlow): class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
""" """
Same as :class:`MapData`, but start processes to run the mapping function, Same as :class:`MapData`, but start processes to run the mapping function,
and communicate with ZeroMQ pipe. and communicate with ZeroMQ pipe.
Note:
1. Processes run in parallel and can take different time to run the
mapping function. Therefore the order of datapoints won't be
preserved, and datapoints from one pass of `df.get_data()` might get
mixed with datapoints from the next pass.
You can use **strict mode**, where `MultiProcessMapData.get_data()`
is guaranteed to produce the exact set which `df.get_data()`
produces. Although the order of data still isn't preserved.
""" """
class _Worker(mp.Process): class _Worker(mp.Process):
def __init__(self, identity, map_func, pipename, hwm): def __init__(self, identity, map_func, pipename, hwm):
...@@ -472,30 +508,32 @@ class MultiProcessMapDataZMQ(_MultiProcessZMQDataFlow): ...@@ -472,30 +508,32 @@ class MultiProcessMapDataZMQ(_MultiProcessZMQDataFlow):
dp = self.map_func(dp) dp = self.map_func(dp)
socket.send(dumps(dp), copy=False) socket.send(dumps(dp), copy=False)
def __init__(self, ds, nr_proc, map_func, buffer_size=200): def __init__(self, ds, nr_proc, map_func, buffer_size=200, strict=False):
""" """
Args: Args:
ds (DataFlow): the dataflow to map ds (DataFlow): the dataflow to map
nr_proc(int): number of threads to use nr_proc(int): number of threads to use
map_func (callable): datapoint -> datapoint | None map_func (callable): datapoint -> datapoint | None
buffer_size (int): number of datapoints in the buffer buffer_size (int): number of datapoints in the buffer
strict (bool): use "strict mode", see notes above.
""" """
super(MultiProcessMapDataZMQ, self).__init__(ds) _ParallelMapData.__init__(self, ds, buffer_size)
_MultiProcessZMQDataFlow.__init__(self)
self.nr_proc = nr_proc self.nr_proc = nr_proc
self.map_func = map_func self.map_func = map_func
self.buffer_size = buffer_size self._strict = strict
self._procs = [] self._procs = []
self._guard = DataFlowReentrantGuard() self._guard = DataFlowReentrantGuard()
def _reset_once(self): def _reset_once(self):
self.context = zmq.Context() self.context = zmq.Context()
self.socket = self.context.socket(zmq.ROUTER) self.socket = self.context.socket(zmq.ROUTER)
self.socket.set_hwm(self.buffer_size * 2) self.socket.set_hwm(self._buffer_size * 2)
pipename = _get_pipe_name('dataflow-map') pipename = _get_pipe_name('dataflow-map')
_bind_guard(self.socket, pipename) _bind_guard(self.socket, pipename)
self._proc_ids = [u'{}'.format(k).encode('utf-8') for k in range(self.nr_proc)] self._proc_ids = [u'{}'.format(k).encode('utf-8') for k in range(self.nr_proc)]
worker_hwm = int(self.buffer_size * 2 // self.nr_proc) worker_hwm = int(self._buffer_size * 2 // self.nr_proc)
self._procs = [MultiProcessMapDataZMQ._Worker( self._procs = [MultiProcessMapDataZMQ._Worker(
self._proc_ids[k], self.map_func, pipename, worker_hwm) self._proc_ids[k], self.map_func, pipename, worker_hwm)
for k in range(self.nr_proc)] for k in range(self.nr_proc)]
...@@ -507,14 +545,8 @@ class MultiProcessMapDataZMQ(_MultiProcessZMQDataFlow): ...@@ -507,14 +545,8 @@ class MultiProcessMapDataZMQ(_MultiProcessZMQDataFlow):
self._start_processes() self._start_processes()
self._fill_buffer() self._fill_buffer()
def _fill_buffer(self): def reset_state(self):
# Filling the buffer. _MultiProcessZMQDataFlow.reset_state(self)
try:
for _ in range(self.buffer_size):
self._send(next(self._iter))
except StopIteration:
logger.error("[MultiProcessMapData] buffer_size cannot be larger than the size of the DataFlow!")
raise
def _send(self, dp): def _send(self, dp):
# round-robin assignment # round-robin assignment
...@@ -529,14 +561,12 @@ class MultiProcessMapDataZMQ(_MultiProcessZMQDataFlow): ...@@ -529,14 +561,12 @@ class MultiProcessMapDataZMQ(_MultiProcessZMQDataFlow):
def get_data(self): def get_data(self):
with self._guard, _zmq_catch_error('MultiProcessMapData'): with self._guard, _zmq_catch_error('MultiProcessMapData'):
for dp in self._iter: if self._strict:
self._send(dp) for dp in self.get_data_strict():
yield self._recv() yield dp
else:
self._iter = self.ds.get_data() # refresh for dp in self.get_data_non_strict():
for _ in range(self.buffer_size): yield dp
self._send(next(self._iter))
yield self._recv()
MultiProcessMapData = MultiProcessMapDataZMQ # alias MultiProcessMapData = MultiProcessMapDataZMQ # alias
...@@ -549,13 +579,13 @@ if __name__ == '__main__': ...@@ -549,13 +579,13 @@ if __name__ == '__main__':
def get_data(self): def get_data(self):
for k in range(self._size): for k in range(self._size):
yield [0] yield [k]
def size(self): def size(self):
return self._size return self._size
ds = Zero(300) ds = Zero(300)
ds = MultiProcessMapData(ds, 3, lambda x: [x[0] + 1]) ds = MultiProcessMapData(ds, 3, lambda x: [x[0] + 1], strict=True)
ds.reset_state() ds.reset_state()
for k in ds.get_data(): for k in ds.get_data():
print("Bang!", k) print("Bang!", k)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment