Commit 7c694aca authored by Yuxin Wu

threaded mapdata

parent d5143723
@@ -48,3 +48,4 @@ class StartProcOrThread(Callback):
             elif isinstance(k, StoppableThread):
                 logger.info("Stopping {} ...".format(k.name))
                 k.stop()
+                k.join()
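The `k.join()` added here makes the callback wait until the worker thread has actually exited, rather than only signalling it to stop. A minimal sketch of the same stop-then-join pattern with a plain `threading.Thread` (the event and worker names are made up, not tensorpack API):

    import threading

    stop_evt = threading.Event()

    def worker():
        # Loop until asked to stop; real code would drain a work queue here.
        while not stop_evt.is_set():
            stop_evt.wait(0.1)

    t = threading.Thread(target=worker, name="worker-0")
    t.start()
    stop_evt.set()   # request a stop, analogous to k.stop()
    t.join()         # block until the thread has really exited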
@@ -102,7 +102,7 @@ class Callbacks(Callback):
                 continue
             for f in fetch:
                 ret.append(f)
-                self._cbid_to_fetchid[idx].append(len(ret)-1)
+                self._cbid_to_fetchid[idx].append(len(ret) - 1)
         self._extra_fetches_cache = ret
         return ret
...
@@ -205,9 +205,8 @@ class MapData(ProxyDataFlow):
             yield ret


-class MapDataComponent(ProxyDataFlow):
+class MapDataComponent(MapData):
     """ Apply a mapper/filter on a datapoint component"""
-
     def __init__(self, ds, func, index=0):
         """
         Args:
@@ -217,16 +216,13 @@ class MapDataComponent(ProxyDataFlow):
             Note that if you use the filter feature, ``ds.size()`` will be incorrect.
             index (int): index of the component.
         """
-        super(MapDataComponent, self).__init__(ds)
-        self.func = func
-        self.index = index
-
-    def get_data(self):
-        for dp in self.ds.get_data():
-            repl = self.func(dp[self.index])
-            if repl is not None:
-                dp[self.index] = repl   # NOTE modifying
-            yield dp
+        def f(dp):
+            r = func(dp[index])
+            if r is None:
+                return None
+            dp[index] = r
+            return dp
+        super(MapDataComponent, self).__init__(ds, f)


 class RepeatedData(ProxyDataFlow):
...
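With this change `MapDataComponent` no longer implements its own `get_data()`: it lifts the per-component function into a per-datapoint function and delegates to `MapData`, which drops any datapoint whose mapped value is `None`. A minimal sketch of the lifting idea outside tensorpack (plain lists stand in for datapoints):

    def make_component_mapper(func, index=0):
        # Turn a function on one component into a function on the whole
        # datapoint, propagating None so the datapoint gets filtered out.
        def f(dp):
            r = func(dp[index])
            if r is None:
                return None
            dp[index] = r
            return dp
        return f

    mapper = make_component_mapper(lambda x: x * 2 if x < 10 else None, index=1)
    print(mapper(["label", 3]))   # ['label', 6]
    print(mapper(["label", 42]))  # None -> dropped by the MapData machinery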
@@ -60,8 +60,23 @@ class AugmentImageComponent(MapDataComponent):
             self.augs = augmentors
         else:
             self.augs = AugmentorList(augmentors)
+        self._nr_error = 0
+
+        def func(x):
+            try:
+                ret = self.augs.augment(x)
+            except KeyboardInterrupt:
+                raise
+            except Exception:
+                self._nr_error += 1
+                if self._nr_error % 1000 == 0:
+                    logger.warn("Got {} augmentation errors.".format(self._nr_error))
+                return None
+            return ret
+
         super(AugmentImageComponent, self).__init__(
-            ds, lambda x: self.augs.augment(x), index)
+            ds, func, index)

     def reset_state(self):
         self.ds.reset_state()
...
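The new inner `func` makes augmentation fault-tolerant: ordinary exceptions turn the datapoint into `None` (so it is silently dropped) while `KeyboardInterrupt` still propagates, and a warning is logged every 1000 failures. The same pattern in isolation (hypothetical `augment_fn` and a stdlib logger as stand-ins, not the tensorpack objects):

    import logging
    logger = logging.getLogger("augment-demo")

    nr_error = 0

    def safe_augment(augment_fn, x):
        # Drop the datapoint on failure instead of crashing the whole
        # pipeline, but never swallow a user-initiated interrupt.
        global nr_error
        try:
            return augment_fn(x)
        except KeyboardInterrupt:
            raise
        except Exception:
            nr_error += 1
            if nr_error % 1000 == 0:
                logger.warning("Got %d augmentation errors.", nr_error)
            return None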
@@ -5,23 +5,24 @@
 from __future__ import print_function
 import multiprocessing as mp
 import itertools
-from six.moves import range, zip
+from six.moves import range, zip, queue
 import uuid
 import os
 import zmq

 from .base import ProxyDataFlow
+from .common import RepeatedData
 from ..utils.concurrency import (ensure_proc_terminate,
-                                 mask_sigint, start_proc_mask_signal)
+                                 mask_sigint, start_proc_mask_signal,
+                                 StoppableThread)
 from ..utils.serialize import loads, dumps
 from ..utils import logger
 from ..utils.gpu import change_gpu

-__all__ = ['PrefetchData', 'BlockParallel', 'PrefetchDataZMQ', 'PrefetchOnGPUs']
+__all__ = ['PrefetchData', 'PrefetchDataZMQ', 'PrefetchOnGPUs']


 class PrefetchProcess(mp.Process):
-
     def __init__(self, ds, queue, reset_after_spawn=True):
         """
         :param ds: ds to take data from
@@ -49,7 +50,6 @@ class PrefetchData(ProxyDataFlow):
     This is significantly slower than :class:`PrefetchDataZMQ` when data
     is large.
     """
-
     def __init__(self, ds, nr_prefetch, nr_proc=1):
         """
         Args:
@@ -83,31 +83,18 @@ class PrefetchData(ProxyDataFlow):
             pass


-def BlockParallel(ds, queue_size):
-    """
-    Insert ``BlockParallel`` in dataflow pipeline to block parallelism on ds.
-    :param ds: a `DataFlow`
-    :param queue_size: size of the queue used
-    """
-    return PrefetchData(ds, queue_size, 1)
-
-
 class PrefetchProcessZMQ(mp.Process):
-    def __init__(self, ds, conn_name):
-        """
-        :param ds: a `DataFlow` instance.
-        :param conn_name: the name of the IPC connection
-        """
+    def __init__(self, ds, conn_name, hwm):
         super(PrefetchProcessZMQ, self).__init__()
         self.ds = ds
         self.conn_name = conn_name
+        self.hwm = hwm

     def run(self):
         self.ds.reset_state()
         self.context = zmq.Context()
         self.socket = self.context.socket(zmq.PUSH)
-        self.socket.set_hwm(5)
+        self.socket.set_hwm(self.hwm)
         self.socket.connect(self.conn_name)
         while True:
             for dp in self.ds.get_data():
@@ -119,13 +106,14 @@ class PrefetchDataZMQ(ProxyDataFlow):
     Prefetch data from a DataFlow using multiple processes, with ZMQ for
     communication.
     """
-    def __init__(self, ds, nr_proc=1, pipedir=None):
+    def __init__(self, ds, nr_proc=1, pipedir=None, hwm=50):
         """
         Args:
             ds (DataFlow): input DataFlow.
             nr_proc (int): number of processes to use.
             pipedir (str): a local directory where the pipes should be put.
                 Useful if you're running on non-local FS such as NFS or GlusterFS.
+            hwm (int): the zmq "high-water mark" for both sender and receiver.
         """
         super(PrefetchDataZMQ, self).__init__(ds)
         try:
@@ -141,10 +129,10 @@ class PrefetchDataZMQ(ProxyDataFlow):
             pipedir = os.environ.get('TENSORPACK_PIPEDIR', '.')
         assert os.path.isdir(pipedir), pipedir
         self.pipename = "ipc://{}/dataflow-pipe-".format(pipedir.rstrip('/')) + str(uuid.uuid1())[:6]
-        self.socket.set_hwm(5)    # a little bit faster than default, don't know why
+        self.socket.set_hwm(hwm)
         self.socket.bind(self.pipename)

-        self.procs = [PrefetchProcessZMQ(self.ds, self.pipename)
+        self.procs = [PrefetchProcessZMQ(self.ds, self.pipename, hwm)
                       for _ in range(self.nr_proc)]
         self.start_processes()
         # __del__ not guaranteed to get called at exit
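ZMQ's high-water mark (HWM) caps how many outstanding messages a socket will buffer before it blocks; replacing the hard-coded `set_hwm(5)` with a configurable `hwm=50` default lets more prefetched datapoints queue up between workers and consumer. A minimal pyzmq sketch of a PUSH/PULL pair with matching HWMs on both ends, mirroring the sender and receiver here (the ipc endpoint name is made up):

    import zmq

    ctx = zmq.Context()

    pull = ctx.socket(zmq.PULL)          # receiver, like the main process
    pull.set_hwm(50)                     # buffer at most ~50 pending messages
    pull.bind("ipc:///tmp/dataflow-pipe-demo")

    push = ctx.socket(zmq.PUSH)          # sender, like a PrefetchProcessZMQ
    push.set_hwm(50)
    push.connect("ipc:///tmp/dataflow-pipe-demo")

    push.send(b"datapoint")
    print(pull.recv())                   # b'datapoint'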
@@ -206,3 +194,71 @@ class PrefetchOnGPUs(PrefetchDataZMQ):
         for gpu, proc in zip(self.gpus, self.procs):
             with change_gpu(gpu):
                 proc.start()
+
+
+class ThreadedMapData(ProxyDataFlow):
+    """
+    Same as :class:`MapData`, but starts threads to run the mapping function.
+    This is useful when the mapping function is the bottleneck, but you don't
+    want to start processes for the entire dataflow pipeline.
+
+    With threads, there is only a tiny communication overhead, but due to the
+    GIL, you should avoid starting the threads in your main process.
+    Note that the threads will only start in the process which calls
+    `reset_state()`.
+    """
+    class WorkerThread(StoppableThread):
+        def __init__(self, inq, outq, map_func):
+            super(ThreadedMapData.WorkerThread, self).__init__()
+            self.inq = inq
+            self.outq = outq
+            self.func = map_func
+
+        def run(self):
+            while not self.stopped():
+                dp = self.queue_get_stoppable(self.inq)
+                dp = self.func(dp)
+                if dp is not None:
+                    self.queue_put_stoppable(self.outq, dp)
+
+    def __init__(self, ds, nr_thread, map_func, buffer_size=200):
+        """
+        Args:
+            ds (DataFlow): the dataflow to map on.
+            nr_thread (int): number of threads to use.
+            map_func (callable): the function to apply to each datapoint;
+                return None to discard the datapoint.
+            buffer_size (int): number of datapoints the buffer can hold.
+        """
+        super(ThreadedMapData, self).__init__(ds)
+        self.infinite_ds = RepeatedData(ds, -1)
+        self.nr_thread = nr_thread
+        self.buffer_size = buffer_size
+        self.map_func = map_func
+        self._threads = []
+
+    def reset_state(self):
+        super(ThreadedMapData, self).reset_state()
+        for t in self._threads:
+            t.stop()
+            t.join()
+
+        self._in_queue = queue.Queue()
+        self._out_queue = queue.Queue()
+        self._threads = [ThreadedMapData.WorkerThread(
+            self._in_queue, self._out_queue, self.map_func)
+            for _ in range(self.nr_thread)]
+        for t in self._threads:
+            t.start()
+
+        # fill the buffer
+        self._itr = self.infinite_ds.get_data()
+        self._fill_buffer()
+
+    def _fill_buffer(self):
+        n = self.buffer_size - self._in_queue.qsize() - self._out_queue.qsize()
+        if n <= 0:
+            return
+        for _ in range(n):
+            self._in_queue.put(next(self._itr))
+
+    def get_data(self):
+        self._fill_buffer()
+        sz = self.size()
+        for _ in range(sz):
+            self._in_queue.put(next(self._itr))
+            yield self._out_queue.get()
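A hedged usage sketch of the new class: wrap a slow mapping function so several threads overlap the work, remembering that the worker threads only start inside `reset_state()` (the input dataflow `DataFromList` is assumed from tensorpack's raw dataflows; the timing is made up):

    import time
    from tensorpack.dataflow import DataFromList
    from tensorpack.dataflow.prefetch import ThreadedMapData

    def slow_map(dp):
        time.sleep(0.01)          # stand-in for an expensive mapping function
        return [dp[0] * 2]

    ds = DataFromList([[i] for i in range(100)])
    ds = ThreadedMapData(ds, nr_thread=8, map_func=slow_map, buffer_size=50)
    ds.reset_state()              # worker threads are created and started here
    for dp in ds.get_data():      # yields self.size() == 100 mapped datapoints
        pass

Per the docstring's GIL caveat, in practice this would itself sit behind a `PrefetchDataZMQ` so the threads run in a child process rather than the main one.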
@@ -22,7 +22,7 @@ _TO_IMPORT = set([
     'gradproc',
     'argscope',
     'tower'
 ])

 _CURR_DIR = os.path.dirname(__file__)
 for _, module_name, _ in iter_modules(
...
@@ -157,8 +157,8 @@ class Trainer(object):
     def global_step(self):
         try:
             return self._starting_step + \
                 self.config.steps_per_epoch * (self.epoch_num - 1) + \
                 self.local_step + 1
         except AttributeError:
             return get_global_step_value()
...
@@ -5,9 +5,9 @@
 import tensorflow as tf

 from ..callbacks import (
     Callbacks, MovingAverageSummary,
     StatPrinter, ProgressBar,
     MaintainStepCounter)
 from ..dataflow.base import DataFlow
 from ..models import ModelDesc
 from ..utils import logger
@@ -86,9 +86,9 @@ class TrainConfig(object):
         assert_type(callbacks, list)
         if extra_callbacks is None:
             extra_callbacks = [
                 MovingAverageSummary(),
                 ProgressBar(),
                 StatPrinter()]
         self.callbacks = [MaintainStepCounter()] + callbacks + extra_callbacks
         assert_type(self.callbacks, list)
         self.callbacks = Callbacks(self.callbacks)
...
 [flake8]
 max-line-length = 120
+ignore = E265
 exclude = .git,
     tensorpack/__init__.py,
     setup.py,
...