Commit 80722088 authored by Yuxin Wu's avatar Yuxin Wu

initial version of multithread predictor

parent 5ccaea83
...@@ -5,6 +5,9 @@ ...@@ -5,6 +5,9 @@
import multiprocessing, threading import multiprocessing, threading
import tensorflow as tf import tensorflow as tf
from six.moves import queue, range
from ..utils.concurrency import DIE from ..utils.concurrency import DIE
from ..tfutils.modelutils import describe_model from ..tfutils.modelutils import describe_model
from ..utils import logger from ..utils import logger
...@@ -12,10 +15,20 @@ from ..tfutils import * ...@@ -12,10 +15,20 @@ from ..tfutils import *
from .common import * from .common import *
__all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker'] try:
if six.PY2:
from tornado.concurrent import Future
else:
from concurrent.futures import Future
except ImportError:
logger.warn("Cannot import Future in either tornado.concurrent or py3 standard lib. MultiThreadAsyncPredictor won't be available.")
__all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker']
else:
__all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker',
'MultiThreadAsyncPredictor']
class MultiProcessPredictWorker(multiprocessing.Process): class MultiProcessPredictWorker(multiprocessing.Process):
""" Base class for predict worker that runs in multiprocess""" """ Base class for predict worker that runs offline in multiprocess"""
def __init__(self, idx, gpuid, config): def __init__(self, idx, gpuid, config):
""" """
:param idx: index of the worker. the 0th worker will print log. :param idx: index of the worker. the 0th worker will print log.
...@@ -44,7 +57,7 @@ class MultiProcessPredictWorker(multiprocessing.Process): ...@@ -44,7 +57,7 @@ class MultiProcessPredictWorker(multiprocessing.Process):
describe_model() describe_model()
class MultiProcessQueuePredictWorker(MultiProcessPredictWorker): class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
""" A worker process to run predictor on one GPU """ """ A predictor worker that takes input and produces output by queue"""
def __init__(self, idx, gpuid, inqueue, outqueue, config): def __init__(self, idx, gpuid, inqueue, outqueue, config):
""" """
:param inqueue: input queue to get data point. elements are (task_id, dp) :param inqueue: input queue to get data point. elements are (task_id, dp)
...@@ -64,17 +77,40 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker): ...@@ -64,17 +77,40 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
else: else:
self.outqueue.put((tid, self.func(dp))) self.outqueue.put((tid, self.func(dp)))
#class CurrentSessionPredictor(): class PerdictorWorkerThread(threading.Thread):
#def __init__(self, idx, gpuid, config): def __init__(self, queue, pred_func):
#""" super(PredictorWorkerThread, self).__init__()
#:param idx: index of the worker. the 0th worker will print log. self.queue = queue
#:param gpuid: absolute id of the GPU to be used. set to -1 to use CPU. self.func = pred_func
#:param config: a `PredictConfig` self.daemon = True
#"""
#super(MultiProcessPredictWorker, self).__init__() def run(self):
#self.idx = idx while True:
#self.gpuid = gpuid inputs, f = self.queue.get()
#self.config = config outputs = self.func(inputs)
f.set_result(outputs)
class MultiThreadAsyncPredictor(object):
"""
An online predictor (use the current active session) that works with
QueueInputTrainer. Use async interface, support multi-thread and multi-GPU.
"""
def __init__(self, trainer, input_names, output_names, nr_thread):
"""
:param trainer: a `QueueInputTrainer` instance.
"""
self.input_queue = queue.Queue()
self.threads = [PredictorWorkerThread(self.input_queue, f)
for f in trainer.get_predict_funcs(input_names, output_names, nr_thread)]
def run(self):
for t in self.threads:
t.start()
#def run(self): def put_task(self, inputs, callback=None):
#pass """ return a Future of output."""
f = Future()
self.input_queue.put((inputs, f))
if callback is not None:
f.add_done_callback(callback)
return f
...@@ -52,9 +52,17 @@ class Trainer(object): ...@@ -52,9 +52,17 @@ class Trainer(object):
@abstractmethod @abstractmethod
def get_predict_func(self, input_names, output_names): def get_predict_func(self, input_names, output_names):
""" return a predict function""" """ return a predictor function"""
pass pass
def get_predict_funcs(self, input_names, output_names, n):
""" return n predictor functions.
Can be overwritten by subclasses to exploit more
parallelism among funcs.
"""
return [self.get_predict_func(input_name, output_names)
for k in range(n)]
def trigger_epoch(self): def trigger_epoch(self):
self._trigger_epoch() self._trigger_epoch()
self.config.callbacks.trigger_epoch() self.config.callbacks.trigger_epoch()
......
...@@ -117,7 +117,7 @@ class QueueInputTrainer(Trainer): ...@@ -117,7 +117,7 @@ class QueueInputTrainer(Trainer):
self.async = async self.async = async
if self.async: if self.async:
assert self.config.nr_tower > 1 assert self.config.nr_tower > 1
self._dequed_inputs = [] self.dequed_inputs = []
@staticmethod @staticmethod
def _average_grads(tower_grads): def _average_grads(tower_grads):
...@@ -140,7 +140,7 @@ class QueueInputTrainer(Trainer): ...@@ -140,7 +140,7 @@ class QueueInputTrainer(Trainer):
assert len(ret) == len(self.input_vars) assert len(ret) == len(self.input_vars)
for qv, v in zip(ret, self.input_vars): for qv, v in zip(ret, self.input_vars):
qv.set_shape(v.get_shape()) qv.set_shape(v.get_shape())
self._dequed_inputs.append(ret) self.dequed_inputs.append(ret)
return ret return ret
def _single_tower_grad(self): def _single_tower_grad(self):
...@@ -241,27 +241,31 @@ class QueueInputTrainer(Trainer): ...@@ -241,27 +241,31 @@ class QueueInputTrainer(Trainer):
summary_str = self.summary_op.eval() summary_str = self.summary_op.eval()
self._process_summary(summary_str) self._process_summary(summary_str)
def get_predict_func(self, input_names, output_names): def get_predict_func(self, input_names, output_names, tower=0):
"""
:param tower: return the kth predict_func
"""
tower = tower % self.config.nr_tower
logger.info("Prepare a predictor function for tower{} ...".format(tower))
raw_input_vars = get_vars_by_names(input_names) raw_input_vars = get_vars_by_names(input_names)
input_var_idxs = [self.input_vars.index(v) for v in raw_input_vars] input_var_idxs = [self.input_vars.index(v) for v in raw_input_vars]
if self.config.nr_tower == 1: dequed = self.dequed_inputs[tower]
dequed = self._dequed_inputs[0] input_vars = [dequed[k] for k in input_var_idxs]
input_vars = [dequed[k] for k in input_var_idxs]
output_vars = get_vars_by_names(output_names)
else:
# TODO naive impl: use the first tower only
dequed = self._dequed_inputs[0]
input_vars = [dequed[k] for k in input_var_idxs]
output_names = ['tower0/' + n for n in output_names]
output_vars = get_vars_by_names(output_names)
if self.config.nr_tower > 1:
output_names = ['tower{}/'.format(tower) + n for n in output_names]
output_vars = get_vars_by_names(output_names)
def func(inputs): def func(inputs):
assert len(inputs) == len(input_vars) assert len(inputs) == len(input_vars)
feed = dict(zip(input_vars, inputs)) feed = dict(zip(input_vars, inputs))
return self.sess.run(output_vars, feed_dict=feed) return self.sess.run(output_vars, feed_dict=feed)
return func return func
def get_predict_funcs(self, input_names, output_names, n):
return [self.get_predict_func(input_name, output_names, k)
for k in range(n)]
def start_train(config): def start_train(config):
tr = QueueInputTrainer(config) tr = QueueInputTrainer(config)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment