Commit c9107add authored by Yuxin Wu's avatar Yuxin Wu

multigpu predictor

parent 1dcc0e72
...@@ -40,6 +40,7 @@ class DumpParamAsImage(Callback): ...@@ -40,6 +40,7 @@ class DumpParamAsImage(Callback):
self.clip = clip self.clip = clip
def _before_train(self): def _before_train(self):
# TODO might not work for multiGPU?
self.var = self.graph.get_tensor_by_name(self.var_name) self.var = self.graph.get_tensor_by_name(self.var_name)
def _trigger_epoch(self): def _trigger_epoch(self):
......
...@@ -6,7 +6,7 @@ import multiprocessing ...@@ -6,7 +6,7 @@ import multiprocessing
from six.moves import range from six.moves import range
from .base import ProxyDataFlow from .base import ProxyDataFlow
from ..utils.concurrency import ensure_procs_terminate from ..utils.concurrency import ensure_proc_terminate
from ..utils import logger from ..utils import logger
__all__ = ['PrefetchData'] __all__ = ['PrefetchData']
...@@ -36,7 +36,8 @@ class PrefetchData(ProxyDataFlow): ...@@ -36,7 +36,8 @@ class PrefetchData(ProxyDataFlow):
""" """
:param ds: a `DataFlow` instance. :param ds: a `DataFlow` instance.
:param nr_prefetch: size of the queue to hold prefetched datapoints. :param nr_prefetch: size of the queue to hold prefetched datapoints.
:param nr_proc: number of processes to use. :param nr_proc: number of processes to use. When larger than 1, order
of data points will be random.
""" """
super(PrefetchData, self).__init__(ds) super(PrefetchData, self).__init__(ds)
self._size = self.size() self._size = self.size()
...@@ -45,7 +46,7 @@ class PrefetchData(ProxyDataFlow): ...@@ -45,7 +46,7 @@ class PrefetchData(ProxyDataFlow):
self.queue = multiprocessing.Queue(self.nr_prefetch) self.queue = multiprocessing.Queue(self.nr_prefetch)
self.procs = [PrefetchProcess(self.ds, self.queue) self.procs = [PrefetchProcess(self.ds, self.queue)
for _ in range(self.nr_proc)] for _ in range(self.nr_proc)]
ensure_procs_terminate(self.procs) ensure_proc_terminate(self.procs)
for x in self.procs: for x in self.procs:
x.start() x.start()
......
...@@ -7,9 +7,13 @@ from itertools import count ...@@ -7,9 +7,13 @@ from itertools import count
import argparse import argparse
from collections import namedtuple from collections import namedtuple
import numpy as np import numpy as np
import bisect
from tqdm import tqdm from tqdm import tqdm
from six.moves import zip from six.moves import zip
import multiprocessing
from .utils.concurrency import ensure_proc_terminate, OrderedResultGatherProc, DIE
from .tfutils import * from .tfutils import *
from .utils import logger from .utils import logger
from .tfutils.modelutils import describe_model from .tfutils.modelutils import describe_model
...@@ -50,6 +54,7 @@ class PredictConfig(object): ...@@ -50,6 +54,7 @@ class PredictConfig(object):
:param output_var_names: a list of names of the output variables to predict, the :param output_var_names: a list of names of the output variables to predict, the
variables can be any computable tensor in the graph. variables can be any computable tensor in the graph.
Predict specific output might not require all input variables. Predict specific output might not require all input variables.
:param nr_gpu: default to 1. Use CUDA_VISIBLE_DEVICES to control which GPU to use sepcifically.
""" """
def assert_type(v, tp): def assert_type(v, tp):
assert isinstance(v, tp), v.__class__ assert isinstance(v, tp), v.__class__
...@@ -59,6 +64,7 @@ class PredictConfig(object): ...@@ -59,6 +64,7 @@ class PredictConfig(object):
self.model = kwargs.pop('model') self.model = kwargs.pop('model')
self.input_data_mapping = kwargs.pop('input_data_mapping', None) self.input_data_mapping = kwargs.pop('input_data_mapping', None)
self.output_var_names = kwargs.pop('output_var_names') self.output_var_names = kwargs.pop('output_var_names')
self.nr_gpu = kwargs.pop('nr_gpu', 1)
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys())) assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
def get_predict_func(config): def get_predict_func(config):
...@@ -81,8 +87,6 @@ def get_predict_func(config): ...@@ -81,8 +87,6 @@ def get_predict_func(config):
output_vars = [tf.get_default_graph().get_tensor_by_name(get_op_var_name(n)[1]) output_vars = [tf.get_default_graph().get_tensor_by_name(get_op_var_name(n)[1])
for n in output_var_names] for n in output_var_names]
describe_model()
sess = tf.Session(config=config.session_config) sess = tf.Session(config=config.session_config)
config.session_init.init(sess) config.session_init.init(sess)
...@@ -101,27 +105,103 @@ def get_predict_func(config): ...@@ -101,27 +105,103 @@ def get_predict_func(config):
PredictResult = namedtuple('PredictResult', ['input', 'output']) PredictResult = namedtuple('PredictResult', ['input', 'output'])
# TODO mutligpu predictor
class PredictWorker(multiprocessing.Process):
def __init__(self, idx, gpuid, inqueue, outqueue, config):
super(PredictWorker, self).__init__()
self.idx = idx
self.gpuid = gpuid
self.inqueue = inqueue
self.outqueue = outqueue
self.config = config
def run(self):
os.environ['CUDA_VISIBLE_DEVICES'] = self.gpuid
G = tf.Graph() # build a graph for each process, because they don't need to share anything
with G.as_default(), tf.device('/gpu:{}'.format(self.idx)):
self.func = get_predict_func(self.config)
if self.idx == 0:
describe_model()
while True:
tid, dp = self.inqueue.get()
if tid == DIE:
self.outqueue.put((DIE, None))
return
else:
res = PredictResult(dp, self.func(dp))
self.outqueue.put((tid, res))
def DFtoQueue(ds, size, nr_consumer):
q = multiprocessing.Queue(size)
class EnqueProc(multiprocessing.Process):
def __init__(self, ds, q, nr_consumer):
super(EnqueProc, self).__init__()
self.ds = ds
self.q = q
def run(self):
for idx, dp in enumerate(self.ds.get_data()):
self.q.put((idx, dp))
print "Enqueue ends"
for _ in range(nr_consumer):
self.q.put((DIE, None))
proc = EnqueProc(ds, q, nr_consumer)
return q, proc
class DatasetPredictor(object): class DatasetPredictor(object):
""" """
Run the predict_config on a given `DataFlow`. Run the predict_config on a given `DataFlow`.
""" """
def __init__(self, predict_config, dataset): def __init__(self, config, dataset):
""" """
:param predict_config: a `PredictConfig` instance. :param config: a `PredictConfig` instance.
:param dataset: a `DataFlow` instance. :param dataset: a `DataFlow` instance.
""" """
assert isinstance(dataset, DataFlow) assert isinstance(dataset, DataFlow)
self.ds = dataset self.ds = dataset
self.predict_func = get_predict_func(predict_config) self.nr_gpu = config.nr_gpu
if self.nr_gpu > 1:
self.inqueue, self.inqueue_proc = DFtoQueue(self.ds, 10, self.nr_gpu)
self.outqueue = multiprocessing.Queue()
try:
gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
except KeyError:
gpus = range(self.nr_gpu)
self.workers = [PredictWorker(i, gpus[i], self.inqueue, self.outqueue, config)
for i in range(self.nr_gpu)]
self.result_queue = OrderedResultGatherProc(self.outqueue)
# run the procs
self.inqueue_proc.start()
for p in self.workers: p.start()
self.result_queue.start()
ensure_proc_terminate(self.workers)
ensure_proc_terminate([self.result_queue, self.inqueue_proc])
else:
self.func = get_predict_func(config)
def get_result(self): def get_result(self):
""" A generator to produce prediction for each data""" """ A generator to produce prediction for each data"""
with tqdm(total=self.ds.size()) as pbar: with tqdm(total=self.ds.size()) as pbar:
for dp in self.ds.get_data(): if self.nr_gpu == 1:
yield PredictResult(dp, self.predict_func(dp)) for dp in self.ds.get_data():
pbar.update() yield PredictResult(dp, self.func(dp))
pbar.update()
else:
while True:
res = self.result_queue.get()
if res[0] != DIE:
yield res[1]
else:
break
pbar.update()
self.inqueue_proc.join()
self.inqueue_proc.terminate()
for p in self.workers:
p.join(); p.terminate()
def get_all_result(self): def get_all_result(self):
""" """
......
...@@ -116,9 +116,11 @@ class QueueInputTrainer(Trainer): ...@@ -116,9 +116,11 @@ class QueueInputTrainer(Trainer):
# get gradients to update: # get gradients to update:
if self.config.nr_tower > 1: if self.config.nr_tower > 1:
logger.info("Training a model of {} tower".format(self.config.nr_tower)) logger.info("Training a model of {} tower".format(self.config.nr_tower))
# to avoid repeated summary from each device # to avoid repeated summary from each device
coll_keys = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY] coll_keys = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
kept_summaries = {} kept_summaries = {}
grad_list = [] grad_list = []
for i in range(self.config.nr_tower): for i in range(self.config.nr_tower):
with tf.device('/gpu:{}'.format(i)), \ with tf.device('/gpu:{}'.format(i)), \
......
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: concurrency.py # File: concurrency.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Credit belongs to Xinyu Zhou
import threading import threading
import multiprocessing import multiprocessing, multiprocess
from contextlib import contextmanager from contextlib import contextmanager
import tensorflow as tf import tensorflow as tf
import atexit import atexit
...@@ -12,6 +13,9 @@ from six.moves import zip ...@@ -12,6 +13,9 @@ from six.moves import zip
from .naming import * from .naming import *
__all__ = ['StoppableThread', 'ensure_proc_terminate',
'OrderedResultGatherProc', 'OrderedContainer', 'DIE']
class StoppableThread(threading.Thread): class StoppableThread(threading.Thread):
def __init__(self): def __init__(self):
super(StoppableThread, self).__init__() super(StoppableThread, self).__init__()
...@@ -24,7 +28,16 @@ class StoppableThread(threading.Thread): ...@@ -24,7 +28,16 @@ class StoppableThread(threading.Thread):
return self._stop.isSet() return self._stop.isSet()
class DIE(object):
pass
def ensure_proc_terminate(proc): def ensure_proc_terminate(proc):
if isinstance(proc, list):
for p in proc:
ensure_proc_terminate(p)
return
def stop_proc_by_weak_ref(ref): def stop_proc_by_weak_ref(ref):
proc = ref() proc = ref()
if proc is None: if proc is None:
...@@ -34,9 +47,58 @@ def ensure_proc_terminate(proc): ...@@ -34,9 +47,58 @@ def ensure_proc_terminate(proc):
proc.terminate() proc.terminate()
proc.join() proc.join()
assert isinstance(proc, multiprocessing.Process) assert isinstance(proc, (multiprocessing.Process, multiprocess.Process))
atexit.register(stop_proc_by_weak_ref, weakref.ref(proc)) atexit.register(stop_proc_by_weak_ref, weakref.ref(proc))
def ensure_procs_terminate(procs):
for p in procs: class OrderedContainer(object):
ensure_proc_terminate(p) def __init__(self, start=0):
self.ranks = []
self.data = []
self.wait_for = start
def put(self, rank, val):
idx = bisect.bisect(self.ranks, rank)
self.ranks.insert(idx, rank)
self.data.insert(idx, val)
def has_next(self):
if len(self.ranks) == 0:
return False
return self.ranks[0] == self.wait_for
def get(self):
assert self.has_next()
ret = self.data[0]
rank = self.ranks[0]
del self.ranks[0]
del self.data[0]
self.wait_for += 1
return rank, ret
class OrderedResultGatherProc(multiprocessing.Process):
def __init__(self, data_queue, start=0):
super(self.__class__, self).__init__()
self.data_queue = data_queue
self.ordered_container = OrderedContainer(start=start)
self.result_queue = multiprocessing.Queue()
def run(self):
try:
while True:
task_id, data = self.data_queue.get()
if task_id == DIE:
self.result_queue.put((task_id, data))
else:
self.ordered_container.put(task_id, data)
while self.ordered_container.has_next():
self.result_queue.put(self.ordered_container.get())
except Exception as e:
import traceback
traceback.print_exc()
raise e
def get(self):
return self.result_queue.get()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment