Commit 49675590 authored by Yuxin Wu's avatar Yuxin Wu

fix use of nr_tower (fix #1077)

parent 1097672b
......@@ -206,6 +206,7 @@ class MySimulatorMaster(SimulatorMaster, Callback):
def train():
assert tf.test.is_gpu_available(), "Training requires GPUs!"
dirname = os.path.join('train_log', 'train-atari-{}'.format(ENV_NAME))
logger.set_logger_dir(dirname)
......@@ -259,7 +260,7 @@ def train():
session_init=get_model_loader(args.load) if args.load else None,
max_epoch=1000,
)
trainer = SimpleTrainer() if config.nr_tower == 1 else AsyncMultiGPUTrainer(train_tower)
trainer = SimpleTrainer() if num_gpu == 1 else AsyncMultiGPUTrainer(train_tower)
launch_train_with_config(config, trainer)
......
......@@ -71,8 +71,6 @@ if __name__ == '__main__':
config = get_config()
if args.gpu:
config.nr_tower = len(args.gpu.split(','))
if args.load:
config.session_init = SaverRestore(args.load)
......
......@@ -104,12 +104,12 @@ class Trainer(object):
The ``tf.Session`` object the trainer is using.
Available after :meth:`initialize()`.
Using ``trainer.sess.run`` to evaluate tensors that depend on the inputs
can lead to unexpected effect:
Using ``trainer.sess.run`` to evaluate tensors that depend on the training
``InputSource`` may have unexpected effect:
For example, if you use ``trainer.sess.run`` to evaluate a tensor that depends on the
inputs coming from a ``StagingArea``,
this will take a datapoint from the ``StagingArea``, making the ``StagingArea`` empty, and as a result
it will take a datapoint from the ``StagingArea``, making the ``StagingArea`` empty, and as a result
make the training hang.
"""
......
......@@ -137,7 +137,7 @@ class AsyncMultiGPUTrainer(SingleCostTrainer):
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
if len(self.devices) > 1:
assert isinstance(input, FeedfreeInput), input
tower_fn = self._make_get_grad_fn(input, get_cost_fn, get_opt_fn),
tower_fn = self._make_get_grad_fn(input, get_cost_fn, get_opt_fn)
grad_list = self._builder.call_for_each_tower(tower_fn)
self.train_op = self._builder.build(grad_list, get_opt_fn)
return []
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment