Commit 05ae7b5d authored by Yuxin Wu

gradprocessor reuse name scope

parent e7aaaf13
@@ -23,6 +23,8 @@ class GradientProcessor(object):
     Subclass should override the ``_process()`` method.
     """
+    _name_scope = None
+
     def process(self, grads):
         """
         Process the symbolic gradients.
@@ -32,8 +34,16 @@ class GradientProcessor(object):
         Returns:
             list: processed gradients, with the same type as input.
         """
-        with tf.name_scope(type(self).__name__):
-            return self._process(grads)
+        # reuse the old name_scope, if process() is called multiple times
+        if self._name_scope is None:
+            with tf.name_scope(type(self).__name__) as scope:
+                self._name_scope = scope
+                return self._process(grads)
+        else:
+            with tf.name_scope(self._name_scope):
+                return self._process(grads)
 
     @abstractmethod
     def _process(self, grads):
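The change hinges on a `tf.name_scope` detail: a plain string is uniquified on every entry (`GradientProcessor`, `GradientProcessor_1`, ...), whereas a captured scope string, which ends with `/`, re-opens the existing scope. A minimal sketch of the same pattern in isolation (names are illustrative, not from the commit; TF1 graph mode assumed):

```python
import tensorflow as tf

_cached_scope = None

def scoped_op():
    # Cache the scope string on first entry; re-enter it afterwards.
    global _cached_scope
    if _cached_scope is None:
        with tf.name_scope("Proc") as scope:  # scope == "Proc/"
            _cached_scope = scope
            return tf.add(1, 1, name="op")    # created as "Proc/op"
    else:
        with tf.name_scope(_cached_scope):    # re-opens "Proc/"
            return tf.add(1, 1, name="op")    # "Proc/op_1", still under "Proc/"

a, b = scoped_op(), scoped_op()
print(a.name, b.name)  # Proc/op:0 Proc/op_1:0, rather than a new Proc_1/ scope
```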
@@ -67,6 +77,7 @@ class GlobalNormClip(GradientProcessor):
         Args:
             global_norm(float): the threshold to clip with.
         """
+        super(GlobalNormClip, self).__init__()
         self._norm = float(global_norm)
 
     def _process(self, grads):
@@ -101,6 +112,7 @@ class MapGradient(GradientProcessor):
         if not regex.endswith('$'):
             regex = regex + '$'
         self.regex = regex
+        super(MapGradient, self).__init__()
 
     def _process(self, grads):
         ret = []
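Since the base class now participates in initialization, the subclasses call the parent constructor explicitly. A hedged sketch of a custom processor following the same convention (the class name, printing logic, and import path are illustrative, not part of the commit):

```python
import tensorflow as tf
from tensorpack.tfutils.gradproc import GradientProcessor  # assumed import path

class PrintGradNorm(GradientProcessor):
    """Illustrative subclass: call the base __init__, override only _process()."""
    def __init__(self):
        super(PrintGradNorm, self).__init__()

    def _process(self, grads):
        # grads is a list of (gradient, variable) pairs
        return [(tf.Print(g, [tf.global_norm([g])]), v) for g, v in grads]
```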
@@ -330,6 +330,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
             # Ngpu x 2
             v = grad_and_vars[0][1]
             with tf.device(v.device):
+                # will call apply_gradients (therefore gradproc) multiple times
                 train_ops.append(opt.apply_gradients(
                     grad_and_vars, name='apply_grad_{}'.format(i)))
         self.train_op = tf.group(*train_ops, name='train_op')
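The new comment records why the scope caching matters: this trainer calls `opt.apply_gradients()` once per GPU, and each call runs the optimizer's gradient processors, so `process()` executes Ngpu times on the same instance. A hedged sketch of the resulting behavior (variable setup and import path are assumptions, not from the commit):

```python
import tensorflow as tf
from tensorpack.tfutils.gradproc import GlobalNormClip  # assumed import path

w = tf.get_variable("w", shape=[3])
grads = [(tf.ones_like(w), w)]

proc = GlobalNormClip(5.0)
g0 = proc.process(grads)  # ops created under "GlobalNormClip/"
g1 = proc.process(grads)  # scope reused: still "GlobalNormClip/", not "GlobalNormClip_1/"
```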