Commit 3a90a5c9 authored by Yuxin Wu's avatar Yuxin Wu

blockparallel & prefetchwithgpus

parent 12bf21bc
...@@ -2,30 +2,30 @@ ...@@ -2,30 +2,30 @@
# File: prefetch.py # File: prefetch.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import multiprocessing import multiprocessing as mp
from threading import Thread from threading import Thread
import itertools import itertools
from six.moves import range from six.moves import range, zip
from six.moves.queue import Queue from six.moves.queue import Queue
import uuid import uuid
import os import os
from .base import ProxyDataFlow from .base import ProxyDataFlow
from ..utils.concurrency import * from ..utils.concurrency import *
from ..utils.serialize import * from ..utils.serialize import loads, dumps
from ..utils import logger from ..utils import logger, change_env
try: try:
import zmq import zmq
except ImportError: except ImportError:
logger.warn("Error in 'import zmq'. PrefetchDataZMQ won't be available.") logger.warn("Error in 'import zmq'. PrefetchDataZMQ won't be available.")
__all__ = ['PrefetchData'] __all__ = ['PrefetchData', 'BlockParallel']
else: else:
__all__ = ['PrefetchData', 'PrefetchDataZMQ'] __all__.extend(['PrefetchDataZMQ', 'PrefetchOnGPUs'])
class PrefetchProcess(multiprocessing.Process): class PrefetchProcess(mp.Process):
def __init__(self, ds, queue): def __init__(self, ds, queue, reset_after_spawn=True):
""" """
:param ds: ds to take data from :param ds: ds to take data from
:param queue: output queue to put results in :param queue: output queue to put results in
...@@ -33,8 +33,10 @@ class PrefetchProcess(multiprocessing.Process): ...@@ -33,8 +33,10 @@ class PrefetchProcess(multiprocessing.Process):
super(PrefetchProcess, self).__init__() super(PrefetchProcess, self).__init__()
self.ds = ds self.ds = ds
self.queue = queue self.queue = queue
self.reset_after_spawn = reset_after_spawn
def run(self): def run(self):
if self.reset_after_spawn:
# reset all ds so each process will produce different data # reset all ds so each process will produce different data
self.ds.reset_state() self.ds.reset_state()
while True: while True:
...@@ -59,7 +61,7 @@ class PrefetchData(ProxyDataFlow): ...@@ -59,7 +61,7 @@ class PrefetchData(ProxyDataFlow):
self._size = -1 self._size = -1
self.nr_proc = nr_proc self.nr_proc = nr_proc
self.nr_prefetch = nr_prefetch self.nr_prefetch = nr_prefetch
self.queue = multiprocessing.Queue(self.nr_prefetch) self.queue = mp.Queue(self.nr_prefetch)
self.procs = [PrefetchProcess(self.ds, self.queue) self.procs = [PrefetchProcess(self.ds, self.queue)
for _ in range(self.nr_proc)] for _ in range(self.nr_proc)]
ensure_proc_terminate(self.procs) ensure_proc_terminate(self.procs)
...@@ -77,7 +79,16 @@ class PrefetchData(ProxyDataFlow): ...@@ -77,7 +79,16 @@ class PrefetchData(ProxyDataFlow):
# do nothing. all ds are reset once and only once in spawned processes # do nothing. all ds are reset once and only once in spawned processes
pass pass
def BlockParallel(ds, queue_size):
    """
    Insert `BlockParallel` in a dataflow pipeline to block parallelism on ds.

    Implemented as a single-process `PrefetchData`, so everything upstream of
    this point runs in one dedicated process feeding a bounded queue.

    :param ds: a `DataFlow`
    :param queue_size: size of the queue used
    """
    return PrefetchData(ds, queue_size, 1)
class PrefetchProcessZMQ(mp.Process):
def __init__(self, ds, conn_name): def __init__(self, ds, conn_name):
""" """
:param ds: a `DataFlow` instance. :param ds: a `DataFlow` instance.
...@@ -158,3 +169,16 @@ class PrefetchDataZMQ(ProxyDataFlow): ...@@ -158,3 +169,16 @@ class PrefetchDataZMQ(ProxyDataFlow):
logger.info("Prefetch process exited.") logger.info("Prefetch process exited.")
except: except:
pass pass
class PrefetchOnGPUs(PrefetchDataZMQ):
    """
    Prefetch with each worker process pinned to one GPU, i.e. each child is
    started under its own CUDA_VISIBLE_DEVICES setting.
    """

    def __init__(self, ds, gpus, pipedir=None):
        """
        :param ds: a `DataFlow` to prefetch from.
        :param gpus: list of GPU ids; one prefetch process is created per id.
        :param pipedir: forwarded to `PrefetchDataZMQ`.
        """
        super(PrefetchOnGPUs, self).__init__(ds, len(gpus), pipedir)
        self.gpus = gpus

    def start_processes(self):
        # Ignore SIGINT in the parent while spawning, so the children do not
        # inherit the interactive interrupt handler mid-start; each child is
        # started with its assigned GPU made the only visible device.
        with mask_sigint():
            for gpu_id, worker in zip(self.gpus, self.procs):
                with change_gpu(gpu_id):
                    worker.start()
...@@ -18,3 +18,4 @@ def _global_import(name): ...@@ -18,3 +18,4 @@ def _global_import(name):
globals()[k] = p.__dict__[k] globals()[k] = p.__dict__[k]
_global_import('naming') _global_import('naming')
_global_import('utils') _global_import('utils')
_global_import('gpu')
...@@ -7,6 +7,7 @@ import threading ...@@ -7,6 +7,7 @@ import threading
import multiprocessing import multiprocessing
import atexit import atexit
import bisect import bisect
from contextlib import contextmanager
import signal import signal
import weakref import weakref
import six import six
...@@ -20,7 +21,7 @@ from . import logger ...@@ -20,7 +21,7 @@ from . import logger
__all__ = ['StoppableThread', 'LoopThread', 'ensure_proc_terminate', __all__ = ['StoppableThread', 'LoopThread', 'ensure_proc_terminate',
'OrderedResultGatherProc', 'OrderedContainer', 'DIE', 'OrderedResultGatherProc', 'OrderedContainer', 'DIE',
'start_proc_mask_signal'] 'mask_sigint', 'start_proc_mask_signal']
class StoppableThread(threading.Thread): class StoppableThread(threading.Thread):
""" """
...@@ -106,14 +107,20 @@ def ensure_proc_terminate(proc): ...@@ -106,14 +107,20 @@ def ensure_proc_terminate(proc):
assert isinstance(proc, multiprocessing.Process) assert isinstance(proc, multiprocessing.Process)
atexit.register(stop_proc_by_weak_ref, weakref.ref(proc)) atexit.register(stop_proc_by_weak_ref, weakref.ref(proc))
@contextmanager
def mask_sigint():
    """
    Context manager that ignores SIGINT inside the managed block.

    Saves the current SIGINT handler, installs ``SIG_IGN``, and restores the
    saved handler on exit -- including when the block raises.
    """
    sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
    try:
        yield
    finally:
        # Restore unconditionally: without the finally, an exception inside
        # the block would leave SIGINT permanently ignored in this process.
        signal.signal(signal.SIGINT, sigint_handler)
def start_proc_mask_signal(proc):
    """
    Start process(es) with SIGINT ignored in the parent while they spawn.

    :param proc: a `multiprocessing.Process`, or a list of them.
    """
    procs = proc if isinstance(proc, list) else [proc]
    with mask_sigint():
        for p in procs:
            p.start()
def subproc_call(cmd, timeout=None): def subproc_call(cmd, timeout=None):
try: try:
......
...@@ -64,25 +64,21 @@ def set_logger_dir(dirname, action=None): ...@@ -64,25 +64,21 @@ def set_logger_dir(dirname, action=None):
global LOG_FILE, LOG_DIR global LOG_FILE, LOG_DIR
if os.path.isdir(dirname): if os.path.isdir(dirname):
_logger.warn("""\ _logger.warn("""\
Directory {} exists! Please either backup/delete it, or use a new directory \ Directory {} exists! Please either backup/delete it, or use a new directory. \
unless you're resuming from a previous task.""".format(dirname)) If you're resuming from a previous run you can choose to keep it.""".format(dirname))
_logger.info("Select Action: k (keep) / b (backup) / d (delete) / n (new):") _logger.info("Select Action: k (keep) / b (backup) / d (delete) / n (new):")
if not action: while not action:
while True: action = input().lower().strip()
act = input().lower().strip()
if act:
break
else:
act = action act = action
if act == 'b': if act == 'b':
backup_name = dirname + get_time_str() backup_name = dirname + get_time_str()
shutil.move(dirname, backup_name) shutil.move(dirname, backup_name)
info("Directory'{}' backuped to '{}'".format(dirname, backup_name)) info("Directory '{}' backuped to '{}'".format(dirname, backup_name))
elif act == 'd': elif act == 'd':
shutil.rmtree(dirname) shutil.rmtree(dirname)
elif act == 'n': elif act == 'n':
dirname = dirname + get_time_str() dirname = dirname + get_time_str()
info("Use a different log directory {}".format(dirname)) info("Use a new log directory {}".format(dirname))
elif act == 'k': elif act == 'k':
pass pass
else: else:
......
...@@ -13,10 +13,9 @@ import six ...@@ -13,10 +13,9 @@ import six
from . import logger from . import logger
__all__ = ['change_env', 'map_arg', __all__ = ['change_env',
'map_arg',
'get_rng', 'memoized', 'get_rng', 'memoized',
'get_nr_gpu',
'get_gpus',
'get_dataset_dir', 'get_dataset_dir',
'get_tqdm_kwargs' 'get_tqdm_kwargs'
] ]
...@@ -96,16 +95,6 @@ def get_rng(obj=None): ...@@ -96,16 +95,6 @@ def get_rng(obj=None):
int(datetime.now().strftime("%Y%m%d%H%M%S%f"))) % 4294967295 int(datetime.now().strftime("%Y%m%d%H%M%S%f"))) % 4294967295
return np.random.RandomState(seed) return np.random.RandomState(seed)
def get_nr_gpu():
    """
    Return the number of GPUs listed in the CUDA_VISIBLE_DEVICES environment
    variable (the number of comma-separated entries).

    :returns: int, the number of visible GPUs.
    :raises ValueError: if CUDA_VISIBLE_DEVICES is not set.
    """
    env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
    if env is None:
        # The original `assert env is not None  # TODO` vanishes under
        # `python -O`; raise an explicit, descriptive error instead.
        raise ValueError("Cannot find 'CUDA_VISIBLE_DEVICES' in environment!")
    return len(env.strip().split(','))
def get_gpus():
    """
    Return the GPU ids in the CUDA_VISIBLE_DEVICES environment variable as a
    list of ints.

    :returns: list of int GPU ids, in the order they appear in the variable.
    :raises ValueError: if CUDA_VISIBLE_DEVICES is not set, or an entry is
        not an integer.
    """
    env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
    if env is None:
        # Explicit raise instead of a bare `assert` (stripped under `-O`).
        raise ValueError("Cannot find 'CUDA_VISIBLE_DEVICES' in environment!")
    # list(...) so callers get a reusable, indexable sequence; on Python 3
    # a bare `map` is a one-shot iterator (this file targets py2/py3 via six).
    return list(map(int, env.strip().split(',')))
def get_dataset_dir(*args): def get_dataset_dir(*args):
d = os.environ.get('TENSORPACK_DATASET', None) d = os.environ.get('TENSORPACK_DATASET', None)
if d is None: if d is None:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment