Commit c9107add authored by Yuxin Wu

multigpu predictor

parent 1dcc0e72
......@@ -40,6 +40,7 @@ class DumpParamAsImage(Callback):
self.clip = clip
def _before_train(self):
# TODO might not work for multiGPU?
self.var = self.graph.get_tensor_by_name(self.var_name)
def _trigger_epoch(self):
......
......@@ -6,7 +6,7 @@ import multiprocessing
from six.moves import range
from .base import ProxyDataFlow
from ..utils.concurrency import ensure_procs_terminate
from ..utils.concurrency import ensure_proc_terminate
from ..utils import logger
__all__ = ['PrefetchData']
......@@ -36,7 +36,8 @@ class PrefetchData(ProxyDataFlow):
"""
:param ds: a `DataFlow` instance.
:param nr_prefetch: size of the queue to hold prefetched datapoints.
:param nr_proc: number of processes to use.
:param nr_proc: number of processes to use. When larger than 1, order
of data points will be random.
"""
super(PrefetchData, self).__init__(ds)
self._size = self.size()
......@@ -45,7 +46,7 @@ class PrefetchData(ProxyDataFlow):
self.queue = multiprocessing.Queue(self.nr_prefetch)
self.procs = [PrefetchProcess(self.ds, self.queue)
for _ in range(self.nr_proc)]
ensure_procs_terminate(self.procs)
ensure_proc_terminate(self.procs)
for x in self.procs:
x.start()
......
......@@ -7,9 +7,13 @@ from itertools import count
import argparse
from collections import namedtuple
import numpy as np
import bisect
from tqdm import tqdm
from six.moves import zip
import multiprocessing
from .utils.concurrency import ensure_proc_terminate, OrderedResultGatherProc, DIE
from .tfutils import *
from .utils import logger
from .tfutils.modelutils import describe_model
......@@ -50,6 +54,7 @@ class PredictConfig(object):
:param output_var_names: a list of names of the output variables to predict, the
variables can be any computable tensor in the graph.
Predict specific output might not require all input variables.
:param nr_gpu: default to 1. Use CUDA_VISIBLE_DEVICES to control which GPU to use specifically.
"""
def assert_type(v, tp):
assert isinstance(v, tp), v.__class__
......@@ -59,6 +64,7 @@ class PredictConfig(object):
self.model = kwargs.pop('model')
self.input_data_mapping = kwargs.pop('input_data_mapping', None)
self.output_var_names = kwargs.pop('output_var_names')
self.nr_gpu = kwargs.pop('nr_gpu', 1)
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
def get_predict_func(config):
......@@ -81,8 +87,6 @@ def get_predict_func(config):
output_vars = [tf.get_default_graph().get_tensor_by_name(get_op_var_name(n)[1])
for n in output_var_names]
describe_model()
sess = tf.Session(config=config.session_config)
config.session_init.init(sess)
......@@ -101,27 +105,103 @@ def get_predict_func(config):
# Pairs a datapoint (`input`) with what the predict function returned for it (`output`).
PredictResult = namedtuple('PredictResult', 'input output')
# TODO multigpu predictor
class PredictWorker(multiprocessing.Process):
    """
    A worker process which runs the predict function on one GPU.

    It reads ``(task_id, datapoint)`` pairs from ``inqueue`` and writes
    ``(task_id, PredictResult)`` pairs to ``outqueue``. When the task id is
    the ``DIE`` sentinel, it forwards ``(DIE, None)`` to ``outqueue`` and exits.
    """
    def __init__(self, idx, gpuid, inqueue, outqueue, config):
        """
        :param idx: index of this worker. Worker 0 also calls `describe_model`.
        :param gpuid: id of the GPU to use; written to CUDA_VISIBLE_DEVICES.
        :param inqueue: queue to read ``(task_id, datapoint)`` from.
        :param outqueue: queue to write ``(task_id, result)`` to.
        :param config: the config passed to `get_predict_func`.
        """
        super(PredictWorker, self).__init__()
        self.idx = idx
        self.gpuid = gpuid
        self.inqueue = inqueue
        self.outqueue = outqueue
        self.config = config

    def run(self):
        import os
        # Environment values must be strings, but the caller may pass an int id.
        os.environ['CUDA_VISIBLE_DEVICES'] = str(self.gpuid)
        G = tf.Graph()     # build a graph for each process, because they don't need to share anything
        # Only one GPU is visible to this process after the line above, so
        # the device is always '/gpu:0' regardless of the worker index.
        with G.as_default(), tf.device('/gpu:0'):
            self.func = get_predict_func(self.config)
            if self.idx == 0:
                describe_model()
            while True:
                tid, dp = self.inqueue.get()
                if tid == DIE:
                    # let the consumer know that this worker has finished
                    self.outqueue.put((DIE, None))
                    return
                else:
                    res = PredictResult(dp, self.func(dp))
                    self.outqueue.put((tid, res))
def DFtoQueue(ds, size, nr_consumer):
    """
    Build a process which feeds all datapoints of a `DataFlow` into a queue.

    Each datapoint is enqueued as ``(index, datapoint)``. After the data ends,
    one ``(DIE, None)`` is enqueued for every consumer.

    :param ds: an object with a ``get_data()`` method yielding datapoints.
    :param size: maximum size of the queue.
    :param nr_consumer: number of ``(DIE, None)`` sentinels to enqueue at the end.
    :returns: a tuple ``(queue, process)``. The process is not started.
    """
    q = multiprocessing.Queue(size)

    class EnqueProc(multiprocessing.Process):
        def __init__(self, ds, q, nr_consumer):
            super(EnqueProc, self).__init__()
            self.ds = ds
            self.q = q
            # keep the argument instead of relying on the enclosing scope
            self.nr_consumer = nr_consumer

        def run(self):
            for idx, dp in enumerate(self.ds.get_data()):
                self.q.put((idx, dp))
            # call form of print: valid on both Python 2 and Python 3
            print("Enqueue ends")
            for _ in range(self.nr_consumer):
                self.q.put((DIE, None))

    proc = EnqueProc(ds, q, nr_consumer)
    return q, proc
class DatasetPredictor(object):
"""
Run the predict_config on a given `DataFlow`.
"""
def __init__(self, config, dataset):
    """
    :param config: a `PredictConfig` instance. When ``config.nr_gpu`` is
        larger than 1, one `PredictWorker` process is started per GPU.
    :param dataset: a `DataFlow` instance.
    """
    assert isinstance(dataset, DataFlow)
    self.ds = dataset
    self.nr_gpu = config.nr_gpu
    if self.nr_gpu > 1:
        import os
        self.inqueue, self.inqueue_proc = DFtoQueue(self.ds, 10, self.nr_gpu)
        self.outqueue = multiprocessing.Queue()
        try:
            gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
        except KeyError:
            # ids end up in os.environ inside the workers, so they must be strings
            gpus = [str(k) for k in range(self.nr_gpu)]
        assert len(gpus) >= self.nr_gpu, \
            "nr_gpu={} but only {} GPU(s) are listed in CUDA_VISIBLE_DEVICES".format(
                self.nr_gpu, len(gpus))
        self.workers = [PredictWorker(i, gpus[i], self.inqueue, self.outqueue, config)
                        for i in range(self.nr_gpu)]
        self.result_queue = OrderedResultGatherProc(self.outqueue)

        # run the procs
        self.inqueue_proc.start()
        for p in self.workers:
            p.start()
        self.result_queue.start()
        ensure_proc_terminate(self.workers)
        ensure_proc_terminate([self.result_queue, self.inqueue_proc])
    else:
        self.func = get_predict_func(config)
def get_result(self):
    """ A generator to produce prediction for each data"""
    with tqdm(total=self.ds.size()) as pbar:
        if self.nr_gpu == 1:
            for dp in self.ds.get_data():
                yield PredictResult(dp, self.func(dp))
                pbar.update()
        else:
            # Every worker forwards one DIE. Stopping at the first one would
            # drop the results still in flight from the other workers.
            die_cnt = 0
            while True:
                res = self.result_queue.get()
                if res[0] != DIE:
                    yield res[1]
                    pbar.update()
                else:
                    die_cnt += 1
                    if die_cnt == self.nr_gpu:
                        break
            # these processes only exist in the multi-GPU setting
            self.inqueue_proc.join()
            self.inqueue_proc.terminate()
            for p in self.workers:
                p.join()
                p.terminate()
def get_all_result(self):
"""
......
......@@ -116,9 +116,11 @@ class QueueInputTrainer(Trainer):
# get gradients to update:
if self.config.nr_tower > 1:
logger.info("Training a model of {} tower".format(self.config.nr_tower))
# to avoid repeated summary from each device
coll_keys = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
kept_summaries = {}
grad_list = []
for i in range(self.config.nr_tower):
with tf.device('/gpu:{}'.format(i)), \
......
# -*- coding: UTF-8 -*-
# File: concurrency.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Credit belongs to Xinyu Zhou
import threading
import multiprocessing
import multiprocessing, multiprocess
from contextlib import contextmanager
import tensorflow as tf
import atexit
......@@ -12,6 +13,9 @@ from six.moves import zip
from .naming import *
__all__ = ['StoppableThread', 'ensure_proc_terminate',
'OrderedResultGatherProc', 'OrderedContainer', 'DIE']
class StoppableThread(threading.Thread):
def __init__(self):
super(StoppableThread, self).__init__()
......@@ -24,7 +28,16 @@ class StoppableThread(threading.Thread):
return self._stop.isSet()
class DIE(object):
    """ Sentinel class, used as the task id in a queue message to signal the end. """
def ensure_proc_terminate(proc):
if isinstance(proc, list):
for p in proc:
ensure_proc_terminate(p)
return
def stop_proc_by_weak_ref(ref):
proc = ref()
if proc is None:
......@@ -34,9 +47,58 @@ def ensure_proc_terminate(proc):
proc.terminate()
proc.join()
assert isinstance(proc, multiprocessing.Process)
assert isinstance(proc, (multiprocessing.Process, multiprocess.Process))
atexit.register(stop_proc_by_weak_ref, weakref.ref(proc))
def ensure_procs_terminate(procs):
    """
    Call `ensure_proc_terminate` on each element of `procs`.

    :param procs: an iterable of processes.
    """
    for one_proc in procs:
        ensure_proc_terminate(one_proc)
class OrderedContainer(object):
    """
    Keep (rank, value) pairs sorted by rank, and release them one by one,
    only when the smallest stored rank is the rank currently waited for.
    """
    def __init__(self, start=0):
        """
        :param start: the first rank to wait for.
        """
        self.wait_for = start
        self.ranks = []
        self.data = []

    def put(self, rank, val):
        """ Store `val` at the position which keeps `ranks` sorted. """
        pos = bisect.bisect(self.ranks, rank)
        self.ranks[pos:pos] = [rank]
        self.data[pos:pos] = [val]

    def has_next(self):
        """ Whether the smallest stored rank is the one being waited for. """
        return bool(self.ranks) and self.ranks[0] == self.wait_for

    def get(self):
        """ Remove and return the (rank, value) pair being waited for. """
        assert self.has_next()
        rank = self.ranks.pop(0)
        val = self.data.pop(0)
        self.wait_for += 1
        return rank, val
class OrderedResultGatherProc(multiprocessing.Process):
    """
    A process which reads ``(task_id, data)`` pairs from a queue and puts
    them to its own result queue, ordered by task id.
    A pair whose task id is ``DIE`` is forwarded right away.
    """
    def __init__(self, data_queue, start=0):
        """
        :param data_queue: the queue to read ``(task_id, data)`` from.
        :param start: the first task id to wait for.
        """
        # Name the class explicitly: super(self.__class__, self) would
        # recurse forever if this class is ever subclassed.
        super(OrderedResultGatherProc, self).__init__()
        self.data_queue = data_queue
        self.ordered_container = OrderedContainer(start=start)
        self.result_queue = multiprocessing.Queue()

    def run(self):
        try:
            while True:
                task_id, data = self.data_queue.get()
                if task_id == DIE:
                    self.result_queue.put((task_id, data))
                else:
                    self.ordered_container.put(task_id, data)
                    # flush everything which is now in order
                    while self.ordered_container.has_next():
                        self.result_queue.put(self.ordered_container.get())
        except Exception:
            import traceback
            traceback.print_exc()
            # bare raise keeps the original traceback
            raise

    def get(self):
        """ Return the next item of the result queue, blocking until one is available. """
        return self.result_queue.get()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment