Commit e943200b authored by Yuxin Wu's avatar Yuxin Wu

fix #897

parent 631f011f
......@@ -347,8 +347,8 @@ class AsyncMultiGPUBuilder(DataParallelBuilder):
"""
ps_device = 'cpu' if len(self.towers) >= 4 else 'gpu'
if ps_device == 'gpu':
raw_devices = ['/gpu:{}'.format(k) for k in self.towers]
if ps_device == 'gpu':
devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
else:
devices = [tf.train.replica_device_setter(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment