Commit 7c694aca authored by Yuxin Wu

threaded mapdata

parent d5143723
@@ -48,3 +48,4 @@ class StartProcOrThread(Callback):
             elif isinstance(k, StoppableThread):
                 logger.info("Stopping {} ...".format(k.name))
                 k.stop()
+                k.join()
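The added `k.join()` makes the callback block until the worker thread has actually exited, instead of merely requesting it to stop. A standalone sketch of this stop-then-join pattern (illustrative plain `threading` code, not tensorpack's `StoppableThread` itself):

import threading

class Worker(threading.Thread):
    """Minimal stoppable thread: stop() only sets a flag."""
    def __init__(self):
        super(Worker, self).__init__()
        self._stop_evt = threading.Event()

    def stop(self):
        self._stop_evt.set()

    def stopped(self):
        return self._stop_evt.is_set()

    def run(self):
        while not self.stopped():
            pass  # one unit of work per iteration

t = Worker()
t.start()
t.stop()   # request exit; the thread may still be mid-iteration
t.join()   # wait until run() has actually returned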
@@ -102,7 +102,7 @@ class Callbacks(Callback):
                 continue
             for f in fetch:
                 ret.append(f)
-                self._cbid_to_fetchid[idx].append(len(ret)-1)
+                self._cbid_to_fetchid[idx].append(len(ret) - 1)
         self._extra_fetches_cache = ret
         return ret
...
@@ -205,9 +205,8 @@ class MapData(ProxyDataFlow):
                 yield ret


-class MapDataComponent(ProxyDataFlow):
+class MapDataComponent(MapData):
     """ Apply a mapper/filter on a datapoint component"""
     def __init__(self, ds, func, index=0):
         """
         Args:
@@ -217,16 +216,13 @@ class MapDataComponent(ProxyDataFlow):
                 Note that if you use the filter feature, ``ds.size()`` will be incorrect.
             index (int): index of the component.
         """
-        super(MapDataComponent, self).__init__(ds)
-        self.func = func
-        self.index = index
-
-    def get_data(self):
-        for dp in self.ds.get_data():
-            repl = self.func(dp[self.index])
-            if repl is not None:
-                dp[self.index] = repl  # NOTE modifying
-                yield dp
+        def f(dp):
+            r = func(dp[index])
+            if r is None:
+                return None
+            dp[index] = r
+            return dp
+        super(MapDataComponent, self).__init__(ds, f)


 class RepeatedData(ProxyDataFlow):
...
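`MapDataComponent` now simply lifts the per-component function into a whole-datapoint mapper and delegates to `MapData`. A self-contained sketch of the resulting semantics (plain Python, no tensorpack imports; all names here are illustrative):

def map_data_component(datapoints, func, index=0):
    # Mirrors the new behavior: func maps dp[index]; returning None
    # discards the whole datapoint (the "filter feature").
    for dp in datapoints:
        r = func(dp[index])
        if r is None:
            continue       # datapoint filtered out
        dp[index] = r      # NOTE: modifies the datapoint in place
        yield dp

dps = [[1, 'a'], [-2, 'b'], [3, 'c']]
out = list(map_data_component(dps, lambda x: x * 2 if x > 0 else None))
assert out == [[2, 'a'], [6, 'c']]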
@@ -60,8 +60,23 @@ class AugmentImageComponent(MapDataComponent):
             self.augs = augmentors
         else:
             self.augs = AugmentorList(augmentors)
+
+        self._nr_error = 0
+
+        def func(x):
+            try:
+                ret = self.augs.augment(x)
+            except KeyboardInterrupt:
+                raise
+            except Exception:
+                self._nr_error += 1
+                if self._nr_error % 1000 == 0:
+                    logger.warn("Got {} augmentation errors.".format(self._nr_error))
+                return None
+            return ret
+
         super(AugmentImageComponent, self).__init__(
-            ds, lambda x: self.augs.augment(x), index)
+            ds, func, index)

     def reset_state(self):
         self.ds.reset_state()
...
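The new wrapper counts augmentation failures and returns None, so a bad image is dropped by `MapDataComponent` instead of crashing the pipeline; `KeyboardInterrupt` is re-raised so Ctrl-C still works. A standalone sketch of the same error-swallowing pattern (illustrative names, not tensorpack APIs):

import logging

def make_tolerant(mapper, log_every=1000):
    # Wrap `mapper` so that exceptions drop the datapoint (return None)
    # instead of propagating; failures are counted and logged sparsely.
    state = {'errors': 0}

    def wrapped(x):
        try:
            return mapper(x)
        except KeyboardInterrupt:
            raise  # never swallow user interrupts
        except Exception:
            state['errors'] += 1
            if state['errors'] % log_every == 0:
                logging.warning("Got %d mapping errors.", state['errors'])
            return None  # None => datapoint gets dropped downstream
    return wrapped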
@@ -5,23 +5,24 @@
 from __future__ import print_function
 import multiprocessing as mp
 import itertools
-from six.moves import range, zip
+from six.moves import range, zip, queue
 import uuid
 import os
 import zmq

 from .base import ProxyDataFlow
+from .common import RepeatedData
 from ..utils.concurrency import (ensure_proc_terminate,
-                                 mask_sigint, start_proc_mask_signal)
+                                 mask_sigint, start_proc_mask_signal,
+                                 StoppableThread)
 from ..utils.serialize import loads, dumps
 from ..utils import logger
 from ..utils.gpu import change_gpu

-__all__ = ['PrefetchData', 'BlockParallel', 'PrefetchDataZMQ', 'PrefetchOnGPUs']
+__all__ = ['PrefetchData', 'PrefetchDataZMQ', 'PrefetchOnGPUs']


 class PrefetchProcess(mp.Process):
     def __init__(self, ds, queue, reset_after_spawn=True):
         """
         :param ds: ds to take data from
@@ -49,7 +50,6 @@ class PrefetchData(ProxyDataFlow):
     This is significantly slower than :class:`PrefetchDataZMQ` when data
     is large.
     """
-
     def __init__(self, ds, nr_prefetch, nr_proc=1):
         """
         Args:
@@ -83,31 +83,18 @@ class PrefetchData(ProxyDataFlow):
             pass


-def BlockParallel(ds, queue_size):
-    """
-    Insert ``BlockParallel`` in dataflow pipeline to block parallelism on ds.
-
-    :param ds: a `DataFlow`
-    :param queue_size: size of the queue used
-    """
-    return PrefetchData(ds, queue_size, 1)
-
-
 class PrefetchProcessZMQ(mp.Process):
-    def __init__(self, ds, conn_name):
-        """
-        :param ds: a `DataFlow` instance.
-        :param conn_name: the name of the IPC connection
-        """
+    def __init__(self, ds, conn_name, hwm):
         super(PrefetchProcessZMQ, self).__init__()
         self.ds = ds
         self.conn_name = conn_name
+        self.hwm = hwm

     def run(self):
         self.ds.reset_state()
         self.context = zmq.Context()
         self.socket = self.context.socket(zmq.PUSH)
-        self.socket.set_hwm(5)
+        self.socket.set_hwm(self.hwm)
         self.socket.connect(self.conn_name)
         while True:
             for dp in self.ds.get_data():
@@ -119,13 +106,14 @@ class PrefetchDataZMQ(ProxyDataFlow):
     Prefetch data from a DataFlow using multiple processes, with ZMQ for
     communication.
     """
-    def __init__(self, ds, nr_proc=1, pipedir=None):
+    def __init__(self, ds, nr_proc=1, pipedir=None, hwm=50):
         """
         Args:
             ds (DataFlow): input DataFlow.
             nr_proc (int): number of processes to use.
             pipedir (str): a local directory where the pipes should be put.
                 Useful if you're running on non-local FS such as NFS or GlusterFS.
+            hwm (int): the zmq "high-water mark" for both sender and receiver.
         """
         super(PrefetchDataZMQ, self).__init__(ds)
         try:
@@ -141,10 +129,10 @@ class PrefetchDataZMQ(ProxyDataFlow):
             pipedir = os.environ.get('TENSORPACK_PIPEDIR', '.')
         assert os.path.isdir(pipedir), pipedir
         self.pipename = "ipc://{}/dataflow-pipe-".format(pipedir.rstrip('/')) + str(uuid.uuid1())[:6]
-        self.socket.set_hwm(5)  # a little bit faster than default, don't know why
+        self.socket.set_hwm(hwm)
         self.socket.bind(self.pipename)

-        self.procs = [PrefetchProcessZMQ(self.ds, self.pipename)
+        self.procs = [PrefetchProcessZMQ(self.ds, self.pipename, hwm)
                       for _ in range(self.nr_proc)]
         self.start_processes()
         # __del__ not guaranteed to get called at exit
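For reference, a hedged usage sketch of the new `hwm` knob (the `FakeData` source is just a stand-in that yields random arrays):

from tensorpack.dataflow import FakeData, PrefetchDataZMQ

ds = FakeData([[224, 224, 3]], size=1000)
# A larger high-water mark lets each ZMQ pipe buffer more datapoints
# (more memory, fewer producer stalls); it replaces the hard-coded 5.
ds = PrefetchDataZMQ(ds, nr_proc=4, hwm=100)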
@@ -206,3 +194,71 @@ class PrefetchOnGPUs(PrefetchDataZMQ):
         for gpu, proc in zip(self.gpus, self.procs):
             with change_gpu(gpu):
                 proc.start()
+
+
+class ThreadedMapData(ProxyDataFlow):
+    """
+    Same as :class:`MapData`, but start threads to run the mapping function.
+    This is useful when the mapping function is the bottleneck, but you don't
+    want to start processes for the entire dataflow pipeline.
+
+    With threads, there is only a tiny communication overhead, but due to the
+    GIL, you should avoid starting the threads in your main process.
+    Note that the threads will only start in the process which calls
+    `reset_state()`.
+    """
+    class WorkerThread(StoppableThread):
+        def __init__(self, inq, outq, map_func):
+            super(ThreadedMapData.WorkerThread, self).__init__()
+            self.inq = inq
+            self.outq = outq
+            self.func = map_func
+
+        def run(self):
+            while not self.stopped():
+                dp = self.queue_get_stoppable(self.inq)
+                dp = self.func(dp)
+                if dp is not None:
+                    self.queue_put_stoppable(self.outq, dp)
+
+    def __init__(self, ds, nr_thread, map_func, buffer_size=200):
+        """
+        Args:
+            ds (DataFlow): the dataflow to map on.
+            nr_thread (int): number of threads to run the mapping function.
+            map_func (callable): takes a datapoint and returns a new
+                datapoint, or None to discard it.
+            buffer_size (int): number of datapoints in the buffer.
+        """
+        super(ThreadedMapData, self).__init__(ds)
+        self.infinite_ds = RepeatedData(ds, -1)
+        self.nr_thread = nr_thread
+        self.buffer_size = buffer_size
+        self.map_func = map_func
+        self._threads = []
+
+    def reset_state(self):
+        super(ThreadedMapData, self).reset_state()
+        for t in self._threads:
+            t.stop()
+            t.join()
+
+        self._in_queue = queue.Queue()
+        self._out_queue = queue.Queue()
+        self._threads = [ThreadedMapData.WorkerThread(
+            self._in_queue, self._out_queue, self.map_func)
+            for _ in range(self.nr_thread)]
+        for t in self._threads:
+            t.start()
+
+        # fill the buffer
+        self._itr = self.infinite_ds.get_data()
+        self._fill_buffer()
+
+    def _fill_buffer(self):
+        n = self.buffer_size - self._in_queue.qsize() - self._out_queue.qsize()
+        if n <= 0:
+            return
+        for _ in range(n):
+            self._in_queue.put(next(self._itr))
+
+    def get_data(self):
+        self._fill_buffer()
+        sz = self.size()
+        for _ in range(sz):
+            self._in_queue.put(next(self._itr))
+            yield self._out_queue.get()
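A hedged usage sketch of `ThreadedMapData` (the shapes and the mapping function are illustrative). Since this commit does not add the class to `__all__`, it is imported from the module directly; per the docstring, the threaded mapper is placed behind a one-process `PrefetchDataZMQ` so its threads start in the forked process rather than the main one:

from tensorpack.dataflow import FakeData, PrefetchDataZMQ
from tensorpack.dataflow.prefetch import ThreadedMapData

def slow_map(dp):
    # stand-in for an expensive decode/augmentation step;
    # returning None would drop the datapoint
    return dp

ds = FakeData([[224, 224, 3]], size=1000)
ds = ThreadedMapData(ds, nr_thread=8, map_func=slow_map, buffer_size=200)
# reset_state() runs inside the forked worker process, so the mapping
# threads do not contend with the main training process for the GIL.
ds = PrefetchDataZMQ(ds, nr_proc=1)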
@@ -22,7 +22,7 @@ _TO_IMPORT = set([
     'gradproc',
     'argscope',
     'tower'
-    ])
+])

 _CURR_DIR = os.path.dirname(__file__)
 for _, module_name, _ in iter_modules(
...
[flake8]
max-line-length = 120
ignore = E265
exclude = .git,
          tensorpack/__init__.py,
          setup.py,
...