Commit 80722088 authored by Yuxin Wu's avatar Yuxin Wu

initial version of multithread predictor

parent 5ccaea83
......@@ -5,6 +5,9 @@
import multiprocessing, threading
import tensorflow as tf
from six.moves import queue, range
from ..utils.concurrency import DIE
from ..tfutils.modelutils import describe_model
from ..utils import logger
......@@ -12,10 +15,20 @@ from ..tfutils import *
from .common import *
__all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker']
try:
if six.PY2:
from tornado.concurrent import Future
else:
from concurrent.futures import Future
except ImportError:
logger.warn("Cannot import Future in either tornado.concurrent or py3 standard lib. MultiThreadAsyncPredictor won't be available.")
__all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker']
else:
__all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker',
'MultiThreadAsyncPredictor']
class MultiProcessPredictWorker(multiprocessing.Process):
""" Base class for predict worker that runs in multiprocess"""
""" Base class for predict worker that runs offline in multiprocess"""
def __init__(self, idx, gpuid, config):
"""
:param idx: index of the worker. the 0th worker will print log.
......@@ -44,7 +57,7 @@ class MultiProcessPredictWorker(multiprocessing.Process):
describe_model()
class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
""" A worker process to run predictor on one GPU """
""" A predictor worker that takes input and produces output by queue"""
def __init__(self, idx, gpuid, inqueue, outqueue, config):
"""
:param inqueue: input queue to get data point. elements are (task_id, dp)
......@@ -64,17 +77,40 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
else:
self.outqueue.put((tid, self.func(dp)))
#class CurrentSessionPredictor():
#def __init__(self, idx, gpuid, config):
#"""
#:param idx: index of the worker. the 0th worker will print log.
#:param gpuid: absolute id of the GPU to be used. set to -1 to use CPU.
#:param config: a `PredictConfig`
#"""
#super(MultiProcessPredictWorker, self).__init__()
#self.idx = idx
#self.gpuid = gpuid
#self.config = config
class PredictorWorkerThread(threading.Thread):
    """Daemon worker thread that serves prediction requests from a shared queue.

    Each queue element is a pair ``(inputs, future)``; the thread runs
    ``pred_func(inputs)`` and fulfills the future with the outputs.
    """
    def __init__(self, queue, pred_func):
        """
        :param queue: a Queue of ``(inputs, future)`` pairs to consume.
        :param pred_func: callable mapping ``inputs`` to outputs.
        """
        # Bug fix: the class was declared as the misspelled
        # ``PerdictorWorkerThread``, so this super() call and the caller in
        # MultiThreadAsyncPredictor (which uses ``PredictorWorkerThread``)
        # raised NameError.
        super(PredictorWorkerThread, self).__init__()
        self.queue = queue
        self.func = pred_func
        self.daemon = True  # don't keep the process alive on exit

    def run(self):
        # Loop forever; daemon flag lets interpreter shutdown kill us.
        while True:
            inputs, f = self.queue.get()
            outputs = self.func(inputs)
            f.set_result(outputs)
class MultiThreadAsyncPredictor(object):
    """
    An online predictor (using the current active session) that works with
    QueueInputTrainer. Asynchronous interface; supports multi-thread and
    multi-GPU prediction via one worker thread per predictor function.
    """
    def __init__(self, trainer, input_names, output_names, nr_thread):
        """
        :param trainer: a `QueueInputTrainer` instance.
        :param input_names: names of the input variables.
        :param output_names: names of the output variables.
        :param nr_thread: number of predictor functions / worker threads.
        """
        self.input_queue = queue.Queue()
        # One worker per predictor function; the trainer may bind each
        # function to a different tower for multi-GPU parallelism.
        self.threads = [
            PredictorWorkerThread(self.input_queue, f)
            for f in trainer.get_predict_funcs(
                input_names, output_names, nr_thread)]

    def run(self):
        """Start all worker threads (daemon threads). Call once."""
        for t in self.threads:
            t.start()

    def put_task(self, inputs, callback=None):
        """Queue a prediction request; return a Future of the outputs.

        :param inputs: the datapoint to feed to a predictor function.
        :param callback: optional callable, invoked with the Future when done.
        """
        f = Future()
        # Attach the callback before enqueuing so it is registered before a
        # worker can possibly complete the future. (Future.add_done_callback
        # also handles the already-done case, so the original order worked,
        # but this order is unambiguous.)
        if callback is not None:
            f.add_done_callback(callback)
        self.input_queue.put((inputs, f))
        return f
......@@ -52,9 +52,17 @@ class Trainer(object):
@abstractmethod
def get_predict_func(self, input_names, output_names):
""" return a predict function"""
""" return a predictor function"""
pass
def get_predict_funcs(self, input_names, output_names, n):
    """ return n predictor functions.
        Can be overwritten by subclasses to exploit more
        parallelism among funcs.

        :param input_names: names of the input variables.
        :param output_names: names of the output variables.
        :param n: number of predictor functions to build.
        :returns: a list of ``n`` predictor functions.
    """
    # Bug fix: the list comprehension previously passed the undefined name
    # ``input_name`` (singular) to get_predict_func, raising NameError.
    return [self.get_predict_func(input_names, output_names)
            for _ in range(n)]
def trigger_epoch(self):
self._trigger_epoch()
self.config.callbacks.trigger_epoch()
......
......@@ -117,7 +117,7 @@ class QueueInputTrainer(Trainer):
self.async = async
if self.async:
assert self.config.nr_tower > 1
self._dequed_inputs = []
self.dequed_inputs = []
@staticmethod
def _average_grads(tower_grads):
......@@ -140,7 +140,7 @@ class QueueInputTrainer(Trainer):
assert len(ret) == len(self.input_vars)
for qv, v in zip(ret, self.input_vars):
qv.set_shape(v.get_shape())
self._dequed_inputs.append(ret)
self.dequed_inputs.append(ret)
return ret
def _single_tower_grad(self):
......@@ -241,27 +241,31 @@ class QueueInputTrainer(Trainer):
summary_str = self.summary_op.eval()
self._process_summary(summary_str)
def get_predict_func(self, input_names, output_names, tower=0):
    """Build a predictor function bound to one training tower.

    :param input_names: names of the input variables.
    :param output_names: names of the output variables (without tower prefix).
    :param tower: which tower's dequeued inputs/outputs to use; wrapped
        modulo ``nr_tower`` so any non-negative index is valid.
    :returns: a function mapping a list of input values to the fetched
        output values, evaluated in the trainer's session.
    """
    # NOTE(review): this span of the diff interleaved the pre-commit version
    # (stale ``self._dequed_inputs`` and a hard-coded 'tower0/' prefix) with
    # the post-commit one; this is the reconstructed coherent new version.
    tower = tower % self.config.nr_tower
    logger.info("Prepare a predictor function for tower{} ...".format(tower))

    raw_input_vars = get_vars_by_names(input_names)
    # Map each requested input to its position in the queue's input tuple.
    input_var_idxs = [self.input_vars.index(v) for v in raw_input_vars]

    dequed = self.dequed_inputs[tower]
    input_vars = [dequed[k] for k in input_var_idxs]
    if self.config.nr_tower > 1:
        # Multi-tower graphs namespace their outputs under 'towerK/'.
        output_names = ['tower{}/'.format(tower) + n for n in output_names]
    output_vars = get_vars_by_names(output_names)

    def func(inputs):
        assert len(inputs) == len(input_vars)
        feed = dict(zip(input_vars, inputs))
        return self.sess.run(output_vars, feed_dict=feed)
    return func
def get_predict_funcs(self, input_names, output_names, n):
    """ return n predictor functions, assigning function k to tower k.

        Overrides the base implementation to exploit multi-tower
        parallelism: each function is bound to a different tower
        (wrapped modulo ``nr_tower`` inside get_predict_func).
    """
    # Bug fix: previously passed the undefined name ``input_name``
    # (singular) to get_predict_func, raising NameError.
    return [self.get_predict_func(input_names, output_names, k)
            for k in range(n)]
def start_train(config):
tr = QueueInputTrainer(config)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment