Commit 8abdaf77 authored by Yuxin Wu's avatar Yuxin Wu

bug fix

parent e592271b
......@@ -246,14 +246,16 @@ if __name__ == '__main__':
train_tower = range(nr_gpu)[:-nr_gpu//2] or [0]
logger.info("[BA3C] Train on gpu {} and infer on gpu {}".format(
','.join(map(str, train_tower)), ','.join(map(str, predict_tower))))
trainer = AsyncMultiGPUTrainer
else:
logger.warn("Without GPU this model will never learn! CPU is only useful for debug.")
nr_gpu = 0
PREDICTOR_THREAD = 1
predict_tower = [0]
train_tower = [0]
trainer = QueueInputTrainer
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
config.tower = train_tower
AsyncMultiGPUTrainer(config, predict_tower=predict_tower).train()
trainer(config, predict_tower=predict_tower).train()
......@@ -93,7 +93,7 @@ class Trainer(object):
val.tag = re.sub('tower[p0-9]+/', '', val.tag) # TODO move to subclasses
suffix = '-summary' # issue#6150
if val.tag.endswith(suffix):
val.tag = va.tag[:-len(suffix)]
val.tag = val.tag[:-len(suffix)]
self.stat_holder.add_stat(val.tag, val.simple_value)
self.summary_writer.add_summary(summary, get_global_step())
......
......@@ -56,6 +56,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
self._setup_predictor_factory(predict_tower)
assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one GPU."
assert tf.test.is_gpu_available()
@staticmethod
......@@ -110,6 +111,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
self._setup_predictor_factory(predict_tower)
self._average_gradient = average_gradient
assert tf.test.is_gpu_available()
def _setup(self):
super(AsyncMultiGPUTrainer, self)._setup()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment