Commit 9505edc6 authored by Yuxin Wu's avatar Yuxin Wu

Improve performance of MapGradient

parent c500fa13
......@@ -128,6 +128,7 @@ class MapGradient(GradientProcessor):
for grad, var in grads:
if re.match(self.regex, var.op.name):
matched = True
with tf.device(grad.device):
grad = self.func(grad, var)
if grad is not None:
ret.append((grad, var))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment