Commit 8abdaf77 authored by Yuxin Wu

bug fix

parent e592271b
...@@ -246,14 +246,16 @@ if __name__ == '__main__': ...@@ -246,14 +246,16 @@ if __name__ == '__main__':
train_tower = range(nr_gpu)[:-nr_gpu//2] or [0] train_tower = range(nr_gpu)[:-nr_gpu//2] or [0]
logger.info("[BA3C] Train on gpu {} and infer on gpu {}".format( logger.info("[BA3C] Train on gpu {} and infer on gpu {}".format(
','.join(map(str, train_tower)), ','.join(map(str, predict_tower)))) ','.join(map(str, train_tower)), ','.join(map(str, predict_tower))))
trainer = AsyncMultiGPUTrainer
else: else:
logger.warn("Without GPU this model will never learn! CPU is only useful for debug.") logger.warn("Without GPU this model will never learn! CPU is only useful for debug.")
nr_gpu = 0 nr_gpu = 0
PREDICTOR_THREAD = 1 PREDICTOR_THREAD = 1
predict_tower = [0] predict_tower = [0]
train_tower = [0] train_tower = [0]
trainer = QueueInputTrainer
config = get_config() config = get_config()
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
config.tower = train_tower config.tower = train_tower
AsyncMultiGPUTrainer(config, predict_tower=predict_tower).train() trainer(config, predict_tower=predict_tower).train()
...@@ -93,7 +93,7 @@ class Trainer(object): ...@@ -93,7 +93,7 @@ class Trainer(object):
val.tag = re.sub('tower[p0-9]+/', '', val.tag) # TODO move to subclasses val.tag = re.sub('tower[p0-9]+/', '', val.tag) # TODO move to subclasses
suffix = '-summary' # issue#6150 suffix = '-summary' # issue#6150
if val.tag.endswith(suffix): if val.tag.endswith(suffix):
val.tag = va.tag[:-len(suffix)] val.tag = val.tag[:-len(suffix)]
self.stat_holder.add_stat(val.tag, val.simple_value) self.stat_holder.add_stat(val.tag, val.simple_value)
self.summary_writer.add_summary(summary, get_global_step()) self.summary_writer.add_summary(summary, get_global_step())
......
...@@ -56,6 +56,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer, ...@@ -56,6 +56,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
self._setup_predictor_factory(predict_tower) self._setup_predictor_factory(predict_tower)
assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one GPU." assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one GPU."
assert tf.test.is_gpu_available()
@staticmethod @staticmethod
...@@ -110,6 +111,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer, ...@@ -110,6 +111,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
self._setup_predictor_factory(predict_tower) self._setup_predictor_factory(predict_tower)
self._average_gradient = average_gradient self._average_gradient = average_gradient
assert tf.test.is_gpu_available()
def _setup(self): def _setup(self):
super(AsyncMultiGPUTrainer, self)._setup() super(AsyncMultiGPUTrainer, self)._setup()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment