Commit 4fc21080 authored by Yuxin Wu

async predictor base

parent e04d846a
Reproduce the following methods:
Reproduce the following reinforcement learning methods:
+ Nature-DQN in:
[Human-level Control Through Deep Reinforcement Learning](http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html)
......
......@@ -5,9 +5,10 @@
from abc import abstractmethod, ABCMeta, abstractproperty
import tensorflow as tf
import six
from ..tfutils import get_vars_by_names
__all__ = ['OnlinePredictor', 'OfflinePredictor']
__all__ = ['OnlinePredictor', 'OfflinePredictor', 'AsyncPredictorBase']
class PredictorBase(object):
......@@ -31,7 +32,27 @@ class PredictorBase(object):
:param dp: input datapoint. Must have the same length as input_var_names.
:return: output as defined by the config
"""
pass
class AsyncPredictorBase(PredictorBase):
@abstractmethod
def put_task(self, dp, callback=None):
"""
:param dp: a data point (a list of components) as input.
(It may be batched or non-batched, depending on the predictor implementation.)
:param callback: a thread-safe callback; it will be called with the list of
outputs, or with the (inputs, outputs) pair.
:return: a Future of outputs
"""
@abstractmethod
def start(self):
""" Start workers """
def _do_call(self, dp):
assert six.PY3, "With Python 2, sync methods are not available for an async predictor"
fut = self.put_task(dp)
# in Tornado, Future.result() doesn't wait
return fut.result()
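For orientation, a small usage sketch of this interface (the subclass instance async_pred and the datapoint dp are hypothetical, not part of this commit). A caller can either attach a callback to the returned Future or block on it; the blocking path needs Python 3 because tornado's Future.result() does not wait.

# Hypothetical usage of an AsyncPredictorBase subclass; `async_pred` and `dp`
# are assumed to exist and are not defined in this commit.
def on_done(fut):
    # with the Future used by the implementation below, the callback
    # receives the Future; its result is the list of outputs
    outputs = fut.result()
    print(outputs)

async_pred.start()                               # spawn the worker threads
fut = async_pred.put_task(dp, callback=on_done)  # non-blocking submission
outputs = fut.result()                           # blocking alternative (Python 3 only)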
class OnlinePredictor(PredictorBase):
def __init__(self, sess, input_vars, output_vars, return_input=False):
......@@ -64,3 +85,19 @@ class OfflinePredictor(OnlinePredictor):
config.session_init.init(sess)
super(OfflinePredictor, self).__init__(
sess, input_vars, output_vars, config.return_input)
class AsyncOnlinePredictor(PredictorBase):
def __init__(self, sess, enqueue_op, output_vars, return_input=False):
"""
:param enqueue_op: an op to feed inputs with.
:param output_vars: a list of directly-runnable (no extra feeding requirements)
vars producing the outputs.
"""
self.session = sess
self.enqop = enqueue_op
self.output_vars = output_vars
self.return_input = return_input
def put_task(self, dp, callback):
pass
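put_task is left as a stub here. Purely as an illustration of the queue-fed design described in the docstring (not the author's implementation), a synchronous version could look roughly like the sketch below; self.input_vars is a hypothetical list of placeholders that enqueue_op feeds from, and a real async version would run the session calls in worker threads.

# Illustrative sketch only -- not part of this commit.
# Assumes hypothetical `self.input_vars` placeholders behind `enqueue_op`,
# and that `output_vars` dequeue from the same in-graph queue.
from concurrent.futures import Future

def _put_task_sketch(self, dp, callback=None):
    fut = Future()
    feed = dict(zip(self.input_vars, dp))
    self.session.run(self.enqop, feed_dict=feed)   # push the inputs into the queue
    outputs = self.session.run(self.output_vars)   # fetch the corresponding outputs
    fut.set_result(outputs)
    if callback is not None:
        callback(fut)
    return fut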
......@@ -16,7 +16,7 @@ from ..utils import logger
from ..utils.timer import *
from ..tfutils import *
from .base import OfflinePredictor
from .base import *
try:
if six.PY2:
......@@ -116,34 +116,39 @@ class PredictorWorkerThread(threading.Thread):
cnt += 1
return batched, futures
class MultiThreadAsyncPredictor(object):
class MultiThreadAsyncPredictor(AsyncPredictorBase):
"""
A multithread predictor which runs a list of predict funcs.
Uses an async interface; supports multi-thread and multi-GPU.
A multithread online async predictor which runs a list of OnlinePredictor.
It does extra batching internally.
"""
def __init__(self, funcs, batch_size=5):
""" :param funcs: a list of predict funcs"""
self.input_queue = queue.Queue(maxsize=len(funcs)*10)
def __init__(self, predictors, batch_size=5):
""" :param predictors: a list of OnlinePredictor"""
for k in predictors:
assert isinstance(k, OnlinePredictor), type(k)
self.input_queue = queue.Queue(maxsize=len(predictors)*10)
self.threads = [
PredictorWorkerThread(
self.input_queue, f, id, batch_size=batch_size)
for id, f in enumerate(funcs)]
for id, f in enumerate(predictors)]
# TODO XXX set logging here to avoid affecting TF logging
import tornado.options as options
options.parse_command_line(['--logging=debug'])
if six.PY2:
# TODO XXX set logging here to avoid affecting TF logging
import tornado.options as options
options.parse_command_line(['--logging=debug'])
def run(self):
def start(self):
for t in self.threads:
t.start()
def put_task(self, inputs, callback=None):
def run(self): # temporarily kept for backward compatibility
self.start()
def put_task(self, dp, callback=None):
"""
dp must be non-batched, i.e. single instance
"""
:param dp: a data point (a list of components) matching input_names (not batched)
:param callback: a thread-safe callback; it will be called with the list of outputs
:returns: a Future of output."""
f = Future()
if callback is not None:
f.add_done_callback(callback)
self.input_queue.put((inputs, f))
self.input_queue.put((dp, f))
return f
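A short usage sketch for context; the OnlinePredictor list preds and the non-batched datapoint dp are assumed here and are not part of this diff.

# Hypothetical usage of MultiThreadAsyncPredictor; `preds` (a list of
# OnlinePredictor, e.g. one per GPU tower) and `dp` are assumed to exist.
async_pred = MultiThreadAsyncPredictor(preds, batch_size=5)
async_pred.start()                     # one PredictorWorkerThread per predictor

def handle(fut):
    outputs = fut.result()             # outputs for this single datapoint
    do_something_with(outputs)         # hypothetical consumer

fut = async_pred.put_task(dp, callback=handle)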
......@@ -62,7 +62,7 @@ class Trainer(object):
Can be overwritten by subclasses to exploit more
parallelism among funcs.
"""
return [self.get_predict_func(input_name, output_names) for k in range(n)]
return [self.get_predict_func(input_names, output_names) for k in range(n)]
def trigger_epoch(self):
self._trigger_epoch()
......
......@@ -30,6 +30,7 @@ class PredictorFactory(object):
self.tower_built = False
def get_predictor(self, input_names, output_names, tower):
""" Return an online predictor"""
if not self.tower_built:
self._build_predict_tower()
tower = self.towers[tower % len(self.towers)]
......@@ -204,7 +205,7 @@ class QueueInputTrainer(Trainer):
self.main_loop()
def run_step(self):
""" just run self.train_op"""
""" Simply run self.train_op"""
self.sess.run(self.train_op)
#run_metadata = tf.RunMetadata()
#self.sess.run([self.train_op],
......