Commit 44f603c0 authored by Yuxin Wu's avatar Yuxin Wu

work around device name problem with colocation (#329)

parent 61c113b8
......@@ -131,10 +131,13 @@ class LeastLoadedDeviceSetter(object):
self.ps_sizes = [0] * len(self.ps_devices)
def __call__(self, op):
def sanitize_name(name): # tensorflow/tensorflow#11484
return tf.DeviceSpec.from_string(name).to_string()
if op.device:
return op.device
if op.type not in ['Variable', 'VariableV2']:
return self.worker_device
return sanitize_name(self.worker_device)
device_index, _ = min(enumerate(
self.ps_sizes), key=operator.itemgetter(1))
......@@ -142,7 +145,7 @@ class LeastLoadedDeviceSetter(object):
var_size = op.outputs[0].get_shape().num_elements()
self.ps_sizes[device_index] += var_size
return device_name
return sanitize_name(device_name)
class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment