Commit 49675590 authored by Yuxin Wu

fix use of nr_tower (fix #1077)

parent 1097672b
@@ -206,6 +206,7 @@ class MySimulatorMaster(SimulatorMaster, Callback):

 def train():
+    assert tf.test.is_gpu_available(), "Training requires GPUs!"
     dirname = os.path.join('train_log', 'train-atari-{}'.format(ENV_NAME))
     logger.set_logger_dir(dirname)
@@ -259,7 +260,7 @@ def train():
         session_init=get_model_loader(args.load) if args.load else None,
         max_epoch=1000,
     )
-    trainer = SimpleTrainer() if config.nr_tower == 1 else AsyncMultiGPUTrainer(train_tower)
+    trainer = SimpleTrainer() if num_gpu == 1 else AsyncMultiGPUTrainer(train_tower)
     launch_train_with_config(config, trainer)
......
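The "+" line in the hunk above bases the trainer choice on num_gpu instead of the removed config.nr_tower. A minimal sketch of that pattern, assuming num_gpu comes from tensorpack's get_num_gpu() helper and that train_tower simply lists one tower per visible GPU (both derivations are assumptions, not shown in this hunk):

from tensorpack import SimpleTrainer, AsyncMultiGPUTrainer, launch_train_with_config
from tensorpack.utils.gpu import get_num_gpu

num_gpu = get_num_gpu()              # assumed source of num_gpu: count of visible GPUs
train_tower = list(range(num_gpu))   # assumed: one training tower per visible GPU

# One GPU -> plain single-tower trainer; several GPUs -> asynchronous multi-GPU training.
trainer = SimpleTrainer() if num_gpu == 1 else AsyncMultiGPUTrainer(train_tower)
launch_train_with_config(config, trainer)   # config is the TrainConfig built earlier in train()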
@@ -71,8 +71,6 @@ if __name__ == '__main__':
     config = get_config()
-    if args.gpu:
-        config.nr_tower = len(args.gpu.split(','))
     if args.load:
         config.session_init = SaverRestore(args.load)
......
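This hunk drops the old config.nr_tower assignment: in the current tensorpack API the number of training towers is decided by the trainer object passed to launch_train_with_config, not by a field on the config. A hedged sketch of the replacement pattern (the choice of SyncMultiGPUTrainerReplicated here is illustrative, not taken from this file):

import os
from tensorpack import SimpleTrainer, SyncMultiGPUTrainerReplicated, launch_train_with_config
from tensorpack.utils.gpu import get_num_gpu

if args.gpu:
    # Restrict visible devices instead of writing config.nr_tower.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

num_gpu = get_num_gpu()
trainer = SimpleTrainer() if num_gpu <= 1 else SyncMultiGPUTrainerReplicated(num_gpu)
launch_train_with_config(config, trainer)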
@@ -104,12 +104,12 @@ class Trainer(object):
         The ``tf.Session`` object the trainer is using.
         Available after :meth:`initialize()`.

-        Using ``trainer.sess.run`` to evaluate tensors that depend on the inputs
-        can lead to unexpected effect:
+        Using ``trainer.sess.run`` to evaluate tensors that depend on the training
+        ``InputSource`` may have unexpected effect:
         For example, if you use ``trainer.sess.run`` to evaluate a tensor that depends on the
         inputs coming from a ``StagingArea``,
-        this will take a datapoint from the ``StagingArea``, making the ``StagingArea`` empty, and as a result
+        it will take a datapoint from the ``StagingArea``, making the ``StagingArea`` empty, and as a result
         make the training hang.
         """
......
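The reworded docstring warns that evaluating such a tensor with trainer.sess.run silently consumes a datapoint from the input pipeline. A small standalone TF 1.x sketch of that consume-on-run behavior, using a plain tf.FIFOQueue to stand in for the StagingArea named in the docstring (same effect: each run of a dependent tensor dequeues an element):

import tensorflow as tf  # TF 1.x graph-mode API

q = tf.FIFOQueue(capacity=10, dtypes=tf.float32)
enqueue = q.enqueue(1.0)
x = q.dequeue()       # stands in for a tensor fed by the input pipeline
loss = x * 2.0        # anything computed from x depends on the queue

with tf.Session() as sess:
    sess.run(enqueue)             # one datapoint buffered
    print(sess.run(q.size()))     # 1
    print(sess.run(loss))         # this run consumes the buffered element as a side effect
    print(sess.run(q.size()))     # 0 -- a training step expecting data here would now block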
@@ -137,7 +137,7 @@ class AsyncMultiGPUTrainer(SingleCostTrainer):
     def _setup_graph(self, input, get_cost_fn, get_opt_fn):
         if len(self.devices) > 1:
             assert isinstance(input, FeedfreeInput), input
-        tower_fn = self._make_get_grad_fn(input, get_cost_fn, get_opt_fn),
+        tower_fn = self._make_get_grad_fn(input, get_cost_fn, get_opt_fn)
         grad_list = self._builder.call_for_each_tower(tower_fn)
         self.train_op = self._builder.build(grad_list, get_opt_fn)
         return []
......
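The only change in this last hunk is removing a trailing comma. In Python that comma made tower_fn a one-element tuple wrapping the function returned by _make_get_grad_fn, rather than the function itself, so the builder received a tuple where it expected a callable. A standalone illustration (the names below are hypothetical, not tensorpack code):

def make_fn():
    # Stand-in for self._make_get_grad_fn(...), which returns a callable.
    return lambda: "gradients"

tower_fn = make_fn(),      # trailing comma: tower_fn is a 1-tuple (the old, buggy line)
print(type(tower_fn))      # <class 'tuple'>
print(callable(tower_fn))  # False -- code expecting a callable breaks

tower_fn = make_fn()       # without the comma: tower_fn is the callable itself (the fix)
print(callable(tower_fn))  # True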