Commit 64a63c5e authored by Yuxin Wu

multi tower prediction graph

parent af2c0e9c
@@ -21,7 +21,7 @@ Both were trained on one GPU with an extra GPU for simulation.
 This is probably the fastest RL trainer you'd find.
 The x-axis is the number of iterations, not wall time.
-Iteration speed on Tesla M40 is about 10.7it/s for B-A3C.
+Iteration speed on Tesla M40 is about 9.7it/s for B-A3C.
 D-DQN is faster at the beginning but will converge to 12it/s due to exploration annealing.
 A demo trained with Double-DQN on breakout is available at [youtube](https://youtu.be/o21mddZtE5Y).
@@ -21,6 +21,7 @@ class PreventStuckPlayer(ProxyPlayer):
         """
         :param nr_repeat: trigger the 'action' after this many of repeated action
         :param action: the action to be triggered to get out of stuck
+        Does auto-reset, but doesn't auto-restart the underlying player.
         """
         super(PreventStuckPlayer, self).__init__(player)
         self.act_que = deque(maxlen=nr_repeat)

@@ -41,7 +42,7 @@ class PreventStuckPlayer(ProxyPlayer):
 class LimitLengthPlayer(ProxyPlayer):
     """ Limit the total number of actions in an episode.
-        Does not auto restart.
+        Does auto-reset, but doesn't auto-restart the underlying player.
     """
     def __init__(self, player, limit):
         super(LimitLengthPlayer, self).__init__(player)

@@ -53,6 +54,8 @@ class LimitLengthPlayer(ProxyPlayer):
         self.cnt += 1
         if self.cnt >= self.limit:
             isOver = True
+        if isOver:
+            self.cnt = 0
         return (r, isOver)

     def restart_episode(self):
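The two wrappers above are normally stacked around a base environment. A minimal sketch of how they compose, assuming an AtariPlayer-style base player; the constructor arguments and action ids below are placeholders, not part of this commit:

# Hypothetical wrapper stack; `AtariPlayer`, the ROM path and the action ids
# are placeholders for whatever base environment is used.
player = AtariPlayer('breakout.bin')
player = PreventStuckPlayer(player, 30, 1)   # fire action 1 after 30 identical actions
player = LimitLengthPlayer(player, 40000)    # force isOver after 40000 actions

r, isOver = player.action(0)
# LimitLengthPlayer now resets its own counter once isOver is True, but the
# caller still has to call player.restart_episode() to restart the game itself.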
@@ -71,7 +71,8 @@ class ModelDesc(object):
     def get_gradient_processor(self):
         """ Return a list of GradientProcessor. They will be executed in order"""
-        return [CheckGradient()]#, SummaryGradient()]
+        return [#SummaryGradient(),
+                CheckGradient()]

 class ModelFromMetaGraph(ModelDesc):
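get_gradient_processor is a hook a user model can override; the default above keeps only CheckGradient, with SummaryGradient left commented out. A minimal sketch of overriding it in a user-defined model (the model class and import paths are assumptions, not part of this commit):

from tensorpack.models import ModelDesc
from tensorpack.tfutils.gradproc import SummaryGradient, CheckGradient

class MyModel(ModelDesc):            # hypothetical user model
    # _get_input_vars / _build_graph elided
    def get_gradient_processor(self):
        # processors are applied to the (grad, var) list in order
        return [SummaryGradient(), CheckGradient()]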
@@ -6,9 +6,13 @@
 from abc import abstractmethod, ABCMeta, abstractproperty
 import tensorflow as tf
 import six

+from ..utils import logger
 from ..tfutils import get_vars_by_names

-__all__ = ['OnlinePredictor', 'OfflinePredictor', 'AsyncPredictorBase']
+__all__ = ['OnlinePredictor', 'OfflinePredictor',
+           'AsyncPredictorBase',
+           'MultiTowerOfflinePredictor', 'build_multi_tower_prediction_graph']

 class PredictorBase(object):

@@ -87,17 +91,43 @@ class OfflinePredictor(OnlinePredictor):
             sess, input_vars, output_vars, config.return_input)

-#class AsyncOnlinePredictor(PredictorBase):
-    #def __init__(self, sess, enqueue_op, output_vars, return_input=False):
-        #"""
-        #:param enqueue_op: an op to feed inputs with.
-        #:param output_vars: a list of directly-runnable (no extra feeding requirements)
-            #vars producing the outputs.
-        #"""
-        #self.session = sess
-        #self.enqop = enqueue_op
-        #self.output_vars = output_vars
-        #self.return_input = return_input
-
-    #def put_task(self, dp, callback):
-        #pass
+def build_multi_tower_prediction_graph(model, towers, prefix='towerp'):
+    """
+    :param towers: a list of gpu relative id.
+    """
+    input_vars = model.get_input_vars()
+    for k in towers:
+        logger.info(
+            "Building graph for predictor tower {}...".format(k))
+        with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'),\
+                tf.name_scope('{}{}'.format(prefix, k)):
+            model._build_graph(input_vars, False)
+            tf.get_variable_scope().reuse_variables()
+
+class MultiTowerOfflinePredictor(OnlinePredictor):
+    PREFIX = 'towerp'
+
+    def __init__(self, config, towers):
+        self.graph = tf.Graph()
+        self.predictors = []
+        with self.graph.as_default():
+            # TODO backup summary keys?
+            build_multi_tower_prediction_graph(config.model, towers, self.PREFIX)
+            self.sess = tf.Session(config=config.session_config)
+            config.session_init.init(self.sess)
+
+            input_vars = get_vars_by_names(config.input_var_names)
+            # use the first tower for compatible PredictorBase interface
+            for k in towers:
+                output_vars = get_vars_by_names(
+                    ['{}{}/'.format(self.PREFIX, k) + n
+                     for n in config.output_var_names])
+                self.predictors.append(OnlinePredictor(
+                    self.sess, input_vars, output_vars, config.return_input))
+
+    def _do_call(self, dp):
+        return self.predictors[0]._do_call(dp)
+
+    def get_predictors(self, n):
+        return [self.predictors[k % len(self.predictors)] for k in range(n)]
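A hypothetical usage sketch of the new predictor (not part of the commit), assuming a PredictConfig populated with the fields the constructor reads above; the model class, checkpoint path, variable names and import paths are placeholders:

from tensorpack.predict import PredictConfig, MultiTowerOfflinePredictor
from tensorpack.tfutils.sessinit import SaverRestore

config = PredictConfig(
    model=MyModel(),                                   # placeholder ModelDesc
    session_init=SaverRestore('train_log/model-10000'),
    input_var_names=['state'],                         # placeholder names
    output_var_names=['policy'])

pred = MultiTowerOfflinePredictor(config, towers=[0, 1])  # relative gpu ids
outputs = pred(dp)                  # calling it directly goes through the first tower (_do_call above)
workers = pred.get_predictors(8)    # 8 predictors round-robined over the two towers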
@@ -42,6 +42,9 @@ class MultiProcessPredictWorker(multiprocessing.Process):
         self.config = config

     def _init_runtime(self):
+        """ Call _init_runtime under a different CUDA_VISIBLE_DEVICES for each worker,
+            and you'll have workers that run on multiple GPUs.
+        """
         if self.idx != 0:
             from tensorpack.models._common import disable_layer_logging
             disable_layer_logging()

@@ -72,6 +75,7 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
             else:
                 self.outqueue.put((tid, self.func(dp)))

 class PredictorWorkerThread(threading.Thread):
     def __init__(self, queue, pred_func, id, batch_size=5):
         super(PredictorWorkerThread, self).__init__()

@@ -118,13 +122,13 @@ class PredictorWorkerThread(threading.Thread):
 class MultiThreadAsyncPredictor(AsyncPredictorBase):
     """
-    An multithread online async predictor which run a list of OnlinePredictor.
+    A multithread online async predictor which runs a list of PredictorBase.
     It would do an extra batching internally.
     """
     def __init__(self, predictors, batch_size=5):
         """ :param predictors: a list of OnlinePredictor"""
         for k in predictors:
-            assert isinstance(k, OnlinePredictor), type(k)
+            #assert isinstance(k, OnlinePredictor), type(k)
             # TODO use predictors.return_input here
             assert k.return_input == False
         self.input_queue = queue.Queue(maxsize=len(predictors)*100)
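With the isinstance assert relaxed above, the per-tower predictors from MultiTowerOfflinePredictor.get_predictors (or from a trainer's get_predict_funcs) can be handed to the batched async predictor. A rough sketch continuing the example above, assuming the run()/put_task(dp, callback) interface of AsyncPredictorBase; the callback payload and import path are assumptions:

from tensorpack.predict import MultiThreadAsyncPredictor

async_pred = MultiThreadAsyncPredictor(pred.get_predictors(8), batch_size=16)
async_pred.run()                       # start the worker threads

def on_result(outputs):                # assumed to receive the outputs for one datapoint
    policy = outputs[0]

async_pred.put_task([state], on_result)   # datapoints from many callers get batched internally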
@@ -48,16 +48,19 @@ class SummaryGradient(GradientProcessor):
                     name=name + '/RMS'))
         return grads

 class CheckGradient(GradientProcessor):
     """
     Check for numeric issue
     """
     def _process(self, grads):
+        ret = []
         for grad, var in grads:
-            # TODO make assert work
-            tf.Assert(tf.reduce_all(tf.is_finite(var)), [var])
-        return grads
+            op = tf.Assert(tf.reduce_all(tf.is_finite(var)),
+                           [var], summarize=100)
+            with tf.control_dependencies([op]):
+                grad = tf.identity(grad)
+            ret.append((grad, var))
+        return ret

 class ScaleGradient(GradientProcessor):
     """
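The rewritten CheckGradient no longer drops the assertion on the floor: each (grad, var) pair is returned with the gradient gated on a tf.Assert over the variable, so the check actually executes when the gradient is consumed. The same TensorFlow pattern in isolation, as a standalone snippet not taken from the commit:

import tensorflow as tf

x = tf.placeholder(tf.float32, [None])
# an Assert op that nothing depends on never runs; gate the tensor on it instead
check = tf.Assert(tf.reduce_all(tf.is_finite(x)), [x], summarize=100)
with tf.control_dependencies([check]):
    x_checked = tf.identity(x)   # any op consuming x_checked now triggers the check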
@@ -16,7 +16,7 @@ from ..tfutils.modelutils import describe_model
 from ..utils import *
 from ..tfutils import *
 from ..tfutils.summary import add_moving_summary
-from ..predict import OnlinePredictor
+from ..predict import OnlinePredictor, build_multi_tower_prediction_graph

 __all__ = ['SimpleTrainer', 'QueueInputTrainer']

@@ -24,32 +24,35 @@ class PredictorFactory(object):
     """ Make predictors for a trainer"""
     PREFIX = 'towerp'

-    def __init__(self, trainer, towers):
-        self.trainer = trainer
+    def __init__(self, sess, model, towers):
+        """
+        :param towers: list of gpu relative id
+        """
+        self.sess = sess
+        self.model = model
         self.towers = towers
         self.tower_built = False

     def get_predictor(self, input_names, output_names, tower):
-        """ Return an online predictor"""
+        """
+        :param tower: need the kth tower (not the gpu id)
+        :returns: an online predictor
+        """
         if not self.tower_built:
             self._build_predict_tower()
         tower = self.towers[tower % len(self.towers)]
         raw_input_vars = get_vars_by_names(input_names)
         output_names = ['{}{}/'.format(self.PREFIX, tower) + n for n in output_names]
         output_vars = get_vars_by_names(output_names)
-        return OnlinePredictor(self.trainer.sess, raw_input_vars, output_vars)
+        return OnlinePredictor(self.sess, raw_input_vars, output_vars)

     def _build_predict_tower(self):
         # build_predict_tower might get called anywhere, but 'towerp' should be the outermost name scope
+        tf.get_variable_scope().reuse_variables()
         with tf.name_scope(None), \
                 freeze_collection(SUMMARY_BACKUP_KEYS):
-            inputs = self.trainer.model.get_input_vars()
-            tf.get_variable_scope().reuse_variables()
-            for k in self.towers:
-                logger.info("Building graph for predictor tower {}...".format(k))
-                with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
-                        tf.name_scope('{}{}'.format(self.PREFIX, k)):
-                    self.trainer.model.build_graph(inputs, False)
+            build_multi_tower_prediction_graph(
+                self.model, self.towers, prefix=self.PREFIX)
         self.tower_built = True

 class SimpleTrainer(Trainer):

@@ -89,7 +92,7 @@ class SimpleTrainer(Trainer):
     def get_predict_func(self, input_names, output_names):
         if not hasattr(self, 'predictor_factory'):
-            self.predictor_factory = PredictorFactory(self, [0])
+            self.predictor_factory = PredictorFactory(self.sess, self.model, [0])
         return self.predictor_factory.get_predictor(input_names, output_names, 0)

 class EnqueueThread(threading.Thread):

@@ -150,11 +153,8 @@ class QueueInputTrainer(Trainer):
         else:
             self.input_queue = input_queue

-        if predict_tower is None:
-            # by default, use the first training gpu for prediction
-            predict_tower = [0]
-        self.predictor_factory = PredictorFactory(self, predict_tower)
+        # by default, use the first training gpu for prediction
+        self.predict_tower = predict_tower or [0]
         self.dequed_inputs = None

     def _get_model_inputs(self):

@@ -233,6 +233,9 @@ class QueueInputTrainer(Trainer):
         :param tower: return the kth predict_func
         :returns: an `OnlinePredictor`
         """
+        if not hasattr(self, 'predictor_factory'):
+            self.predictor_factory = PredictorFactory(
+                self.sess, self.model, self.predict_tower)
         return self.predictor_factory.get_predictor(input_names, output_names, tower)

     def get_predict_funcs(self, input_names, output_names, n):
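On the trainer side, the PredictorFactory is now built lazily from (sess, model, predict_tower), so QueueInputTrainer only constructs the 'towerp' prediction towers the first time a predict function is requested. A rough sketch of the intended call pattern; the trainer config, tensor names and tower ids below are placeholders, not part of this commit:

# Hypothetical: train on gpus 0-1, keep gpu 2 for the prediction tower.
trainer = QueueInputTrainer(train_config, predict_tower=[2])

# later, once the session exists (e.g. from inside a callback):
pred_funcs = trainer.get_predict_funcs(
    ['state'], ['policy', 'value'], 8)   # 8 OnlinePredictors cycling over predict_tower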