Commit e943200b authored by Yuxin Wu's avatar Yuxin Wu

fix #897

parent 631f011f
...@@ -347,8 +347,8 @@ class AsyncMultiGPUBuilder(DataParallelBuilder): ...@@ -347,8 +347,8 @@ class AsyncMultiGPUBuilder(DataParallelBuilder):
""" """
ps_device = 'cpu' if len(self.towers) >= 4 else 'gpu' ps_device = 'cpu' if len(self.towers) >= 4 else 'gpu'
if ps_device == 'gpu':
raw_devices = ['/gpu:{}'.format(k) for k in self.towers] raw_devices = ['/gpu:{}'.format(k) for k in self.towers]
if ps_device == 'gpu':
devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices] devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
else: else:
devices = [tf.train.replica_device_setter( devices = [tf.train.replica_device_setter(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment