Commit a131b8fd authored by Yuxin Wu's avatar Yuxin Wu

Share stop event between ThreadedMapData workers

parent 1401d30d
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
from __future__ import print_function from __future__ import print_function
import threading
import multiprocessing as mp import multiprocessing as mp
import itertools import itertools
from six.moves import range, zip, queue from six.moves import range, zip, queue
...@@ -270,8 +271,8 @@ class ThreadedMapData(ProxyDataFlow): ...@@ -270,8 +271,8 @@ class ThreadedMapData(ProxyDataFlow):
produces. Although the order of data still isn't preserved. produces. Although the order of data still isn't preserved.
""" """
class _WorkerThread(StoppableThread): class _WorkerThread(StoppableThread):
def __init__(self, inq, outq, map_func, strict): def __init__(self, inq, outq, evt, map_func, strict):
super(ThreadedMapData._WorkerThread, self).__init__() super(ThreadedMapData._WorkerThread, self).__init__(evt)
self.inq = inq self.inq = inq
self.outq = outq self.outq = outq
self.func = map_func self.func = map_func
...@@ -279,14 +280,25 @@ class ThreadedMapData(ProxyDataFlow): ...@@ -279,14 +280,25 @@ class ThreadedMapData(ProxyDataFlow):
self._strict = strict self._strict = strict
def run(self): def run(self):
while not self.stopped(): try:
dp = self.queue_get_stoppable(self.inq) while True:
dp = self.func(dp) dp = self.queue_get_stoppable(self.inq)
if dp is not None: if self.stopped():
self.outq.put(dp) return
dp = self.func(dp)
if dp is not None:
self.outq.put(dp)
else:
assert not self._strict, \
"[ThreadedMapData] Map function cannot return None when strict mode is used."
except:
if self.stopped():
pass # skip duplicated error messages
else: else:
assert not self._strict, \ raise
"[ThreadedMapData] Map function cannot return None when strict mode is used." finally:
self.stop()
def __init__(self, ds, nr_thread, map_func, buffer_size=200, strict=False): def __init__(self, ds, nr_thread, map_func, buffer_size=200, strict=False):
""" """
...@@ -308,13 +320,16 @@ class ThreadedMapData(ProxyDataFlow): ...@@ -308,13 +320,16 @@ class ThreadedMapData(ProxyDataFlow):
def reset_state(self): def reset_state(self):
super(ThreadedMapData, self).reset_state() super(ThreadedMapData, self).reset_state()
for t in self._threads: if self._threads:
t.stop() self._threads[0].stop()
t.join() for t in self._threads:
t.join()
self._in_queue = queue.Queue() self._in_queue = queue.Queue()
self._out_queue = queue.Queue() self._out_queue = queue.Queue()
self._evt = threading.Event()
self._threads = [ThreadedMapData._WorkerThread( self._threads = [ThreadedMapData._WorkerThread(
self._in_queue, self._out_queue, self.map_func, self._strict) self._in_queue, self._out_queue, self._evt, self.map_func, self._strict)
for _ in range(self.nr_thread)] for _ in range(self.nr_thread)]
for t in self._threads: for t in self._threads:
t.start() t.start()
...@@ -357,6 +372,6 @@ class ThreadedMapData(ProxyDataFlow): ...@@ -357,6 +372,6 @@ class ThreadedMapData(ProxyDataFlow):
yield self._out_queue.get() yield self._out_queue.get()
def __del__(self): def __del__(self):
self._evt.set()
for p in self._threads: for p in self._threads:
p.stop()
p.join() p.join()
...@@ -32,9 +32,15 @@ class StoppableThread(threading.Thread): ...@@ -32,9 +32,15 @@ class StoppableThread(threading.Thread):
A thread that has a 'stop' event. A thread that has a 'stop' event.
""" """
def __init__(self): def __init__(self, evt=None):
"""
Args:
evt(threading.Event): if None, will create one.
"""
super(StoppableThread, self).__init__() super(StoppableThread, self).__init__()
self._stop_evt = threading.Event() if evt is None:
evt = threading.Event()
self._stop_evt = evt
def stop(self): def stop(self):
""" Stop the thread""" """ Stop the thread"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment