Commit a131b8fd authored by Yuxin Wu's avatar Yuxin Wu

Share stop event between ThreadedMapData workers

parent 1401d30d
......@@ -3,6 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from __future__ import print_function
import threading
import multiprocessing as mp
import itertools
from six.moves import range, zip, queue
......@@ -270,8 +271,8 @@ class ThreadedMapData(ProxyDataFlow):
produces. Although the order of data still isn't preserved.
"""
class _WorkerThread(StoppableThread):
def __init__(self, inq, outq, map_func, strict):
super(ThreadedMapData._WorkerThread, self).__init__()
def __init__(self, inq, outq, evt, map_func, strict):
super(ThreadedMapData._WorkerThread, self).__init__(evt)
self.inq = inq
self.outq = outq
self.func = map_func
......@@ -279,14 +280,25 @@ class ThreadedMapData(ProxyDataFlow):
self._strict = strict
def run(self):
while not self.stopped():
dp = self.queue_get_stoppable(self.inq)
dp = self.func(dp)
if dp is not None:
self.outq.put(dp)
try:
while True:
dp = self.queue_get_stoppable(self.inq)
if self.stopped():
return
dp = self.func(dp)
if dp is not None:
self.outq.put(dp)
else:
assert not self._strict, \
"[ThreadedMapData] Map function cannot return None when strict mode is used."
except:
if self.stopped():
pass # skip duplicated error messages
else:
assert not self._strict, \
"[ThreadedMapData] Map function cannot return None when strict mode is used."
raise
finally:
self.stop()
def __init__(self, ds, nr_thread, map_func, buffer_size=200, strict=False):
"""
......@@ -308,13 +320,16 @@ class ThreadedMapData(ProxyDataFlow):
def reset_state(self):
super(ThreadedMapData, self).reset_state()
for t in self._threads:
t.stop()
t.join()
if self._threads:
self._threads[0].stop()
for t in self._threads:
t.join()
self._in_queue = queue.Queue()
self._out_queue = queue.Queue()
self._evt = threading.Event()
self._threads = [ThreadedMapData._WorkerThread(
self._in_queue, self._out_queue, self.map_func, self._strict)
self._in_queue, self._out_queue, self._evt, self.map_func, self._strict)
for _ in range(self.nr_thread)]
for t in self._threads:
t.start()
......@@ -357,6 +372,6 @@ class ThreadedMapData(ProxyDataFlow):
yield self._out_queue.get()
def __del__(self):
self._evt.set()
for p in self._threads:
p.stop()
p.join()
......@@ -32,9 +32,15 @@ class StoppableThread(threading.Thread):
A thread that has a 'stop' event.
"""
def __init__(self):
def __init__(self, evt=None):
"""
Args:
evt(threading.Event): if None, will create one.
"""
super(StoppableThread, self).__init__()
self._stop_evt = threading.Event()
if evt is None:
evt = threading.Event()
self._stop_evt = evt
def stop(self):
""" Stop the thread"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment