Commit 64a63c5e authored by Yuxin Wu

multi tower prediction graph

parent af2c0e9c
......@@ -21,7 +21,7 @@ Both were trained on one GPU with an extra GPU for simulation.
This is probably the fastest RL trainer you'd find.
The x-axis is the number of iterations, not wall time.
Iteration speed on Tesla M40 is about 10.7it/s for B-A3C.
Iteration speed on Tesla M40 is about 9.7it/s for B-A3C.
D-DQN is faster at the beginning but will converge to 12it/s due to exploration annealing.
A demo trained with Double-DQN on breakout is available at [youtube](https://youtu.be/o21mddZtE5Y).
......
......@@ -21,6 +21,7 @@ class PreventStuckPlayer(ProxyPlayer):
"""
:param nr_repeat: trigger the 'action' after this many of repeated action
:param action: the action to be triggered to get out of stuck
Does auto-reset, but doesn't auto-restart the underlying player.
"""
super(PreventStuckPlayer, self).__init__(player)
self.act_que = deque(maxlen=nr_repeat)
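For context, the stuck-detection above reduces to keeping a fixed-length deque of recent actions and firing once they are all identical. A minimal standalone sketch of that idea (StuckDetector is a hypothetical name, not tensorpack API):

from collections import deque

class StuckDetector(object):
    # flag a 'stuck' state once the same action has been seen
    # nr_repeat times in a row
    def __init__(self, nr_repeat):
        self.act_que = deque(maxlen=nr_repeat)

    def is_stuck(self, action):
        self.act_que.append(action)
        return (len(self.act_que) == self.act_que.maxlen
                and len(set(self.act_que)) == 1)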
......@@ -41,7 +42,7 @@ class PreventStuckPlayer(ProxyPlayer):
class LimitLengthPlayer(ProxyPlayer):
""" Limit the total number of actions in an episode.
Does not auto restart.
Does auto-reset, but doesn't auto-restart the underlying player.
"""
def __init__(self, player, limit):
super(LimitLengthPlayer, self).__init__(player)
......@@ -53,6 +54,8 @@ class LimitLengthPlayer(ProxyPlayer):
        self.cnt += 1
        if self.cnt >= self.limit:
            isOver = True
        if isOver:
            self.cnt = 0
        return (r, isOver)

    def restart_episode(self):
......
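To see the new reset behavior, a hedged usage sketch (DummyPlayer is hypothetical, and ProxyPlayer is assumed to merely store the wrapped player):

class DummyPlayer(object):
    def action(self, act):
        return (0.0, False)      # an episode that never ends on its own

player = LimitLengthPlayer(DummyPlayer(), limit=3)
for step in range(6):
    r, isOver = player.action(0)
    # isOver is True on steps 2 and 5 (0-based): the counter auto-resets
    # whenever the capped episode ends, so the next episode also gets 3 steps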
......@@ -71,7 +71,8 @@ class ModelDesc(object):
    def get_gradient_processor(self):
        """ Return a list of GradientProcessor. They will be executed in order."""
        return [CheckGradient()]#, SummaryGradient()]
        return [#SummaryGradient(),
                CheckGradient()]
class ModelFromMetaGraph(ModelDesc):
......
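A ModelDesc subclass can override this hook to pick its own processors; a hedged sketch using the two processors touched by this commit (MyModel is hypothetical):

class MyModel(ModelDesc):
    def get_gradient_processor(self):
        # processors run in list order over the (grad, var) pairs
        return [SummaryGradient(),   # add RMS summaries for each gradient
                CheckGradient()]     # then assert everything is finite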
......@@ -6,9 +6,13 @@
from abc import abstractmethod, ABCMeta, abstractproperty
import tensorflow as tf
import six
from ..utils import logger
from ..tfutils import get_vars_by_names
__all__ = ['OnlinePredictor', 'OfflinePredictor', 'AsyncPredictorBase']
__all__ = ['OnlinePredictor', 'OfflinePredictor',
           'AsyncPredictorBase',
           'MultiTowerOfflinePredictor', 'build_multi_tower_prediction_graph']
class PredictorBase(object):
......@@ -87,17 +91,43 @@ class OfflinePredictor(OnlinePredictor):
sess, input_vars, output_vars, config.return_input)
#class AsyncOnlinePredictor(PredictorBase):
    #def __init__(self, sess, enqueue_op, output_vars, return_input=False):
        #"""
        #:param enqueue_op: an op to feed inputs with.
        #:param output_vars: a list of directly-runnable (no extra feeding requirements)
        #    vars producing the outputs.
        #"""
        #self.session = sess
        #self.enqop = enqueue_op
        #self.output_vars = output_vars
        #self.return_input = return_input

    #def put_task(self, dp, callback):
        #pass
def build_multi_tower_prediction_graph(model, towers, prefix='towerp'):
    """
    :param towers: a list of relative GPU ids.
    """
    input_vars = model.get_input_vars()
    for k in towers:
        logger.info(
            "Building graph for predictor tower {}...".format(k))
        with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
                tf.name_scope('{}{}'.format(prefix, k)):
            model._build_graph(input_vars, False)
            tf.get_variable_scope().reuse_variables()
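The per-tower name scope keeps op names distinct while reuse_variables() shares one set of weights, so each output exists once per tower under the 'towerp{k}' prefix. A hedged sketch of fetching them (TF 1.x-era API; 'logits' is a placeholder output name):

build_multi_tower_prediction_graph(model, towers=[0, 1])
g = tf.get_default_graph()
out0 = g.get_tensor_by_name('towerp0/logits:0')   # tower 0's copy
out1 = g.get_tensor_by_name('towerp1/logits:0')   # tower 1's copy, same weights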
class MultiTowerOfflinePredictor(OnlinePredictor):
    PREFIX = 'towerp'

    def __init__(self, config, towers):
        self.graph = tf.Graph()
        self.predictors = []
        with self.graph.as_default():
            # TODO backup summary keys?
            build_multi_tower_prediction_graph(config.model, towers, self.PREFIX)
            self.sess = tf.Session(config=config.session_config)
            config.session_init.init(self.sess)

            input_vars = get_vars_by_names(config.input_var_names)
            for k in towers:
                output_vars = get_vars_by_names(
                    ['{}{}/'.format(self.PREFIX, k) + n
                     for n in config.output_var_names])
                self.predictors.append(OnlinePredictor(
                    self.sess, input_vars, output_vars, config.return_input))

    def _do_call(self, dp):
        # use the first tower for a compatible PredictorBase interface
        return self.predictors[0]._do_call(dp)

    def get_predictors(self, n):
        return [self.predictors[k % len(self.predictors)] for k in range(n)]
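Usage would look roughly like this (a hedged sketch; config is a PredictConfig with the fields consumed above, and predictors are assumed to be callable on a datapoint):

pred = MultiTowerOfflinePredictor(config, towers=[0, 1])  # one tower per GPU
handles = pred.get_predictors(8)     # 8 handles, round-robin over the 2 towers
outputs = handles[0](dp)             # each handle is a plain OnlinePredictor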
......@@ -42,6 +42,9 @@ class MultiProcessPredictWorker(multiprocessing.Process):
        self.config = config

    def _init_runtime(self):
        """ Call _init_runtime under different CUDA_VISIBLE_DEVICES and you'll
        have workers that run on multiple GPUs.
        """
        if self.idx != 0:
            from tensorpack.models._common import disable_layer_logging
            disable_layer_logging()
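The docstring refers to the standard trick for pinning one worker per GPU: set CUDA_VISIBLE_DEVICES before TensorFlow initializes, so each process sees only its own device. A hedged standalone sketch:

import os
import multiprocessing

def worker(idx, gpus):
    # must happen before TF is imported/initialized in this process;
    # the chosen device then appears to the process as gpu:0
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpus[idx % len(gpus)])
    # ... import tensorflow, build the graph, serve predictions ...

if __name__ == '__main__':
    for i in range(4):
        multiprocessing.Process(target=worker, args=(i, [0, 1])).start()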
......@@ -72,6 +75,7 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
            else:
                self.outqueue.put((tid, self.func(dp)))
class PredictorWorkerThread(threading.Thread):
    def __init__(self, queue, pred_func, id, batch_size=5):
        super(PredictorWorkerThread, self).__init__()
......@@ -118,13 +122,13 @@ class PredictorWorkerThread(threading.Thread):
class MultiThreadAsyncPredictor(AsyncPredictorBase):
"""
    An multithread online async predictor which run a list of OnlinePredictor.
    A multithreaded online async predictor which runs a list of PredictorBase.
    It does an extra round of batching internally.
"""
    def __init__(self, predictors, batch_size=5):
        """ :param predictors: a list of OnlinePredictor"""
        for k in predictors:
            assert isinstance(k, OnlinePredictor), type(k)
            #assert isinstance(k, OnlinePredictor), type(k)
            # TODO use predictors.return_input here
            assert k.return_input == False
        self.input_queue = queue.Queue(maxsize=len(predictors)*100)
......
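The extra batching works by draining up to batch_size pending requests from a shared queue and running them through one forward pass. A minimal sketch of that loop (not the tensorpack class; pred_func is assumed to map a list of inputs to a list of outputs):

from six.moves import queue

def batch_worker(input_queue, pred_func, batch_size=5):
    while True:
        dp, cb = input_queue.get()            # block for the first request
        batch, callbacks = [dp], [cb]
        while len(batch) < batch_size:        # opportunistically grab more
            try:
                dp, cb = input_queue.get_nowait()
            except queue.Empty:
                break
            batch.append(dp)
            callbacks.append(cb)
        for cb, output in zip(callbacks, pred_func(batch)):
            cb(output)                        # dispatch results to callers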
......@@ -48,16 +48,19 @@ class SummaryGradient(GradientProcessor):
                                   name=name + '/RMS'))
        return grads


class CheckGradient(GradientProcessor):
    """
    Check gradients for numeric issues.
    """
    def _process(self, grads):
        ret = []
        for grad, var in grads:
            # TODO make assert work
            tf.Assert(tf.reduce_all(tf.is_finite(var)), [var])
        return grads
            op = tf.Assert(tf.reduce_all(tf.is_finite(var)),
                           [var], summarize=100)
            with tf.control_dependencies([op]):
                grad = tf.identity(grad)
            ret.append((grad, var))
        return ret
class ScaleGradient(GradientProcessor):
    """
......
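The control_dependencies dance above is the standard workaround for tf.Assert being dead code unless something the graph actually runs depends on it. The pattern in isolation (TF 1.x-era API; a hedged sketch):

def checked(tensor, to_check):
    # the Assert op only executes because `tensor` is rewired to depend on it
    op = tf.Assert(tf.reduce_all(tf.is_finite(to_check)), [to_check],
                   summarize=100)
    with tf.control_dependencies([op]):
        return tf.identity(tensor)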
......@@ -16,7 +16,7 @@ from ..tfutils.modelutils import describe_model
from ..utils import *
from ..tfutils import *
from ..tfutils.summary import add_moving_summary
from ..predict import OnlinePredictor
from ..predict import OnlinePredictor, build_multi_tower_prediction_graph
__all__ = ['SimpleTrainer', 'QueueInputTrainer']
......@@ -24,32 +24,35 @@ class PredictorFactory(object):
""" Make predictors for a trainer"""
PREFIX = 'towerp'
def __init__(self, trainer, towers):
self.trainer = trainer
def __init__(self, sess, model, towers):
"""
:param towers: list of gpu relative id
"""
self.sess = sess
self.model = model
self.towers = towers
self.tower_built = False
    def get_predictor(self, input_names, output_names, tower):
        """ Return an online predictor"""
        """
        :param tower: the kth predict tower (not the GPU id)
        :returns: an online predictor
        """
        if not self.tower_built:
            self._build_predict_tower()
        tower = self.towers[tower % len(self.towers)]
        raw_input_vars = get_vars_by_names(input_names)
        output_names = ['{}{}/'.format(self.PREFIX, tower) + n for n in output_names]
        output_vars = get_vars_by_names(output_names)
        return OnlinePredictor(self.trainer.sess, raw_input_vars, output_vars)
        return OnlinePredictor(self.sess, raw_input_vars, output_vars)
    def _build_predict_tower(self):
        # build_predict_tower might get called anywhere, but 'towerp' should be the outermost name scope
        tf.get_variable_scope().reuse_variables()
        with tf.name_scope(None), \
                freeze_collection(SUMMARY_BACKUP_KEYS):
            inputs = self.trainer.model.get_input_vars()
            tf.get_variable_scope().reuse_variables()
            for k in self.towers:
                logger.info("Building graph for predictor tower {}...".format(k))
                with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
                        tf.name_scope('{}{}'.format(self.PREFIX, k)):
                    self.trainer.model.build_graph(inputs, False)
            build_multi_tower_prediction_graph(
                self.model, self.towers, prefix=self.PREFIX)
        self.tower_built = True
class SimpleTrainer(Trainer):
......@@ -89,7 +92,7 @@ class SimpleTrainer(Trainer):
    def get_predict_func(self, input_names, output_names):
        if not hasattr(self, 'predictor_factory'):
            self.predictor_factory = PredictorFactory(self, [0])
            self.predictor_factory = PredictorFactory(self.sess, self.model, [0])
        return self.predictor_factory.get_predictor(input_names, output_names, 0)
class EnqueueThread(threading.Thread):
......@@ -150,11 +153,8 @@ class QueueInputTrainer(Trainer):
else:
self.input_queue = input_queue
if predict_tower is None:
# by default, use the first training gpu for prediction
predict_tower = [0]
self.predictor_factory = PredictorFactory(self, predict_tower)
# by default, use the first training gpu for prediction
self.predict_tower = predict_tower or [0]
self.dequed_inputs = None
def _get_model_inputs(self):
......@@ -233,6 +233,9 @@ class QueueInputTrainer(Trainer):
        :param tower: return the kth predict_func
        :returns: an `OnlinePredictor`
        """
        if not hasattr(self, 'predictor_factory'):
            self.predictor_factory = PredictorFactory(
                self.sess, self.model, self.predict_tower)
        return self.predictor_factory.get_predictor(input_names, output_names, tower)

    def get_predict_funcs(self, input_names, output_names, n):
......
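Callers then grab n predict functions and the factory spreads them round-robin over the predict towers, e.g. (a hedged sketch; the variable names are placeholders):

predict_funcs = trainer.get_predict_funcs(['state'], ['policy', 'value'], 16)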