Commit 9505edc6 authored by Yuxin Wu's avatar Yuxin Wu

Improve performance of MapGradient

parent c500fa13
...@@ -128,7 +128,8 @@ class MapGradient(GradientProcessor): ...@@ -128,7 +128,8 @@ class MapGradient(GradientProcessor):
for grad, var in grads: for grad, var in grads:
if re.match(self.regex, var.op.name): if re.match(self.regex, var.op.name):
matched = True matched = True
grad = self.func(grad, var) with tf.device(grad.device):
grad = self.func(grad, var)
if grad is not None: if grad is not None:
ret.append((grad, var)) ret.append((grad, var))
else: else:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment