Commit 3a90a5c9 authored by Yuxin Wu

blockparallel & prefetchwithgpus

parent 12bf21bc
tensorpack/dataflow/prefetch.py
@@ -2,30 +2,30 @@
 # File: prefetch.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

-import multiprocessing
+import multiprocessing as mp
 from threading import Thread
 import itertools
-from six.moves import range
+from six.moves import range, zip
 from six.moves.queue import Queue
 import uuid
 import os

 from .base import ProxyDataFlow
 from ..utils.concurrency import *
-from ..utils.serialize import *
-from ..utils import logger
+from ..utils.serialize import loads, dumps
+from ..utils import logger, change_env

 try:
     import zmq
 except ImportError:
     logger.warn("Error in 'import zmq'. PrefetchDataZMQ won't be available.")
-    __all__ = ['PrefetchData']
+    __all__ = ['PrefetchData', 'BlockParallel']
 else:
-    __all__ = ['PrefetchData', 'PrefetchDataZMQ']
+    __all__.extend(['PrefetchDataZMQ', 'PrefetchOnGPUs'])

-class PrefetchProcess(multiprocessing.Process):
-    def __init__(self, ds, queue):
+class PrefetchProcess(mp.Process):
+    def __init__(self, ds, queue, reset_after_spawn=True):
         """
         :param ds: ds to take data from
         :param queue: output queue to put results in
@@ -33,8 +33,10 @@ class PrefetchProcess(multiprocessing.Process):
         super(PrefetchProcess, self).__init__()
         self.ds = ds
         self.queue = queue
+        self.reset_after_spawn = reset_after_spawn

     def run(self):
+        if self.reset_after_spawn:
             # reset all ds so each process will produce different data
             self.ds.reset_state()
         while True:
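
The new `reset_after_spawn` flag controls the reset that makes forked workers diverge: children inherit the parent's RNG state, so unless each one calls `reset_state()` after the fork they all emit the same "random" stream. A minimal standalone sketch of the failure mode (illustrative code, not part of this commit):

import multiprocessing as mp
import numpy as np

rng = np.random.RandomState(0)    # parent state, copied into every child

def worker(q):
    q.put(int(rng.randint(100)))  # same inherited state -> same number

if __name__ == '__main__':
    q = mp.Queue()
    procs = [mp.Process(target=worker, args=(q,)) for _ in range(3)]
    for p in procs:
        p.start()
    print([q.get() for _ in procs])  # three identical values
    for p in procs:
        p.join()
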
@@ -59,7 +61,7 @@ class PrefetchData(ProxyDataFlow):
         self._size = -1
         self.nr_proc = nr_proc
         self.nr_prefetch = nr_prefetch
-        self.queue = multiprocessing.Queue(self.nr_prefetch)
+        self.queue = mp.Queue(self.nr_prefetch)
         self.procs = [PrefetchProcess(self.ds, self.queue)
                       for _ in range(self.nr_proc)]
         ensure_proc_terminate(self.procs)
@@ -77,7 +79,16 @@ class PrefetchData(ProxyDataFlow):
         # do nothing. all ds are reset once and only once in spawned processes
         pass

-class PrefetchProcessZMQ(multiprocessing.Process):
+def BlockParallel(ds, queue_size):
+    """
+    Insert `BlockParallel` into a dataflow pipeline to block parallelism on `ds`.
+
+    :param ds: a `DataFlow`
+    :param queue_size: size of the queue used
+    """
+    return PrefetchData(ds, queue_size, 1)
+
+class PrefetchProcessZMQ(mp.Process):
     def __init__(self, ds, conn_name):
         """
         :param ds: a `DataFlow` instance.
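
`BlockParallel` is nothing more than `PrefetchData` with a single worker: the bounded queue decouples `ds` from its consumer, so everything upstream of the call runs in exactly one background process. A hedged usage sketch, assuming `ds0` is any `DataFlow` whose generator is expensive and `consume` stands in for your training loop:

ds = BlockParallel(ds0, queue_size=50)  # one producer process + queue of 50
ds.reset_state()
for dp in ds.get_data():    # datapoints arrive from the background process
    consume(dp)
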
@@ -158,3 +169,16 @@ class PrefetchDataZMQ(ProxyDataFlow):
             logger.info("Prefetch process exited.")
         except:
             pass
+
+
+class PrefetchOnGPUs(PrefetchDataZMQ):
+    """ Prefetch with each process having a specific CUDA_VISIBLE_DEVICES"""
+
+    def __init__(self, ds, gpus, pipedir=None):
+        super(PrefetchOnGPUs, self).__init__(ds, len(gpus), pipedir)
+        self.gpus = gpus
+
+    def start_processes(self):
+        with mask_sigint():
+            for gpu, proc in zip(self.gpus, self.procs):
+                with change_gpu(gpu):
+                    proc.start()
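
`PrefetchOnGPUs` overrides `start_processes` so that each of the `len(gpus)` ZMQ workers is forked under a different `CUDA_VISIBLE_DEVICES`, with SIGINT masked during the forks. A hedged usage sketch (the dataflow name is illustrative):

# run a GPU-using dataflow on four devices, one worker per GPU
ds = PrefetchOnGPUs(my_gpu_dataflow, gpus=[0, 1, 2, 3])
ds.reset_state()
for dp in ds.get_data():
    ...  # each datapoint was produced by a worker pinned to one GPU
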
tensorpack/utils/__init__.py
@@ -18,3 +18,4 @@ def _global_import(name):
         globals()[k] = p.__dict__[k]
 _global_import('naming')
 _global_import('utils')
+_global_import('gpu')
tensorpack/utils/concurrency.py
@@ -7,6 +7,7 @@ import threading
 import multiprocessing
 import atexit
 import bisect
+from contextlib import contextmanager
 import signal
 import weakref
 import six
@@ -20,7 +21,7 @@ from . import logger
 __all__ = ['StoppableThread', 'LoopThread', 'ensure_proc_terminate',
            'OrderedResultGatherProc', 'OrderedContainer', 'DIE',
-           'start_proc_mask_signal']
+           'mask_sigint', 'start_proc_mask_signal']


 class StoppableThread(threading.Thread):
     """
@@ -106,14 +107,20 @@ def ensure_proc_terminate(proc):
     assert isinstance(proc, multiprocessing.Process)
     atexit.register(stop_proc_by_weak_ref, weakref.ref(proc))

+
+@contextmanager
+def mask_sigint():
+    sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
+    yield
+    signal.signal(signal.SIGINT, sigint_handler)
+
+
 def start_proc_mask_signal(proc):
     if not isinstance(proc, list):
         proc = [proc]

-    sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
+    with mask_sigint():
         for p in proc:
             p.start()
-    signal.signal(signal.SIGINT, sigint_handler)

 def subproc_call(cmd, timeout=None):
     try:
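
`mask_sigint` factors the handler save/restore out of `start_proc_mask_signal` so other call sites (such as `PrefetchOnGPUs.start_processes` above) can reuse it: processes started inside the block inherit `SIG_IGN`, so a later Ctrl-C reaches only the parent. Note the `yield` is not wrapped in `try`/`finally`, so an exception inside the block would leave SIGINT ignored. A usage sketch, assuming the import path `tensorpack.utils.concurrency`:

import multiprocessing as mp
from tensorpack.utils.concurrency import mask_sigint

def work():
    ...  # long-running worker body

procs = [mp.Process(target=work) for _ in range(2)]
with mask_sigint():        # SIGINT ignored while forking
    for p in procs:
        p.start()          # children inherit SIG_IGN
# handler restored: Ctrl-C now interrupts only the parent
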
tensorpack/utils/logger.py
@@ -64,25 +64,21 @@ def set_logger_dir(dirname, action=None):
     global LOG_FILE, LOG_DIR
     if os.path.isdir(dirname):
         _logger.warn("""\
-Directory {} exists! Please either backup/delete it, or use a new directory \
-unless you're resuming from a previous task.""".format(dirname))
+Directory {} exists! Please either backup/delete it, or use a new directory. \
+If you're resuming from a previous run you can choose to keep it.""".format(dirname))
         _logger.info("Select Action: k (keep) / b (backup) / d (delete) / n (new):")
-        if not action:
-            while True:
-                act = input().lower().strip()
-                if act:
-                    break
-        else:
+        while not action:
+            action = input().lower().strip()
         act = action
         if act == 'b':
             backup_name = dirname + get_time_str()
             shutil.move(dirname, backup_name)
-            info("Directory'{}' backuped to '{}'".format(dirname, backup_name))
+            info("Directory '{}' backed up to '{}'".format(dirname, backup_name))
         elif act == 'd':
             shutil.rmtree(dirname)
         elif act == 'n':
             dirname = dirname + get_time_str()
-            info("Use a different log directory {}".format(dirname))
+            info("Use a new log directory {}".format(dirname))
         elif act == 'k':
             pass
         else:
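
With the new 'k' (keep) choice, the prompt loop also became simpler: it re-asks until a non-empty action is typed, and a caller can skip the prompt entirely by passing `action`. A usage sketch (the directory name is illustrative):

from tensorpack.utils.logger import set_logger_dir

# resuming a previous run: keep the existing directory without prompting;
# 'b', 'd' and 'n' would back up, delete, or pick a new name instead
set_logger_dir('train_log/run1', action='k')
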
tensorpack/utils/utils.py
@@ -13,10 +13,9 @@ import six
 from . import logger

-__all__ = ['change_env', 'map_arg',
+__all__ = ['change_env',
+           'map_arg',
            'get_rng', 'memoized',
-           'get_nr_gpu',
-           'get_gpus',
            'get_dataset_dir',
            'get_tqdm_kwargs'
            ]
@@ -96,16 +95,6 @@ def get_rng(obj=None):
         int(datetime.now().strftime("%Y%m%d%H%M%S%f"))) % 4294967295
     return np.random.RandomState(seed)

-def get_nr_gpu():
-    env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
-    assert env is not None    # TODO
-    return len(env.split(','))
-
-def get_gpus():
-    env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
-    assert env is not None    # TODO
-    return map(int, env.strip().split(','))
-
 def get_dataset_dir(*args):
     d = os.environ.get('TENSORPACK_DATASET', None)
     if d is None:
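
Note that `get_nr_gpu` and `get_gpus` are not gone: together with the `_global_import('gpu')` added above and the `change_gpu` context manager used by `PrefetchOnGPUs`, they have evidently moved into a new `utils/gpu.py`, which this diff does not show. A plausible reconstruction from the removed code and the call sites (an assumption, not the commit's actual file):

# hypothetical utils/gpu.py, inferred from this commit's call sites
import os
from .utils import change_env

__all__ = ['change_gpu', 'get_nr_gpu', 'get_gpus']

def change_gpu(val):
    # context manager: temporarily set CUDA_VISIBLE_DEVICES to one device
    return change_env('CUDA_VISIBLE_DEVICES', str(val))

def get_nr_gpu():
    env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
    assert env is not None, 'CUDA_VISIBLE_DEVICES must be set'
    return len(env.split(','))

def get_gpus():
    env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
    assert env is not None, 'CUDA_VISIBLE_DEVICES must be set'
    return list(map(int, env.strip().split(',')))

In this sketch `get_gpus` returns a list rather than a lazy `map` object, which also fixes a Python 3 quirk in the removed version.
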