Commit 05ae7b5d authored by Yuxin Wu

gradprocessor reuse name scope

parent e7aaaf13
@@ -23,6 +23,8 @@ class GradientProcessor(object):
     Subclass should override the ``_process()`` method.
     """
+    _name_scope = None
+
     def process(self, grads):
         """
         Process the symbolic gradients.
@@ -32,8 +34,16 @@ class GradientProcessor(object):
         Returns:
             list: processed gradients, with the same type as input.
         """
-        with tf.name_scope(type(self).__name__):
-            return self._process(grads)
+        # reuse the old name_scope, if process() is called multiple times
+        if self._name_scope is None:
+            with tf.name_scope(type(self).__name__) as scope:
+                self._name_scope = scope
+                return self._process(grads)
+        else:
+            with tf.name_scope(self._name_scope):
+                return self._process(grads)

     @abstractmethod
     def _process(self, grads):
@@ -67,6 +77,7 @@ class GlobalNormClip(GradientProcessor):
         Args:
             global_norm(float): the threshold to clip with.
         """
+        super(GlobalNormClip, self).__init__()
         self._norm = float(global_norm)

     def _process(self, grads):
@@ -101,6 +112,7 @@ class MapGradient(GradientProcessor):
         if not regex.endswith('$'):
             regex = regex + '$'
         self.regex = regex
+        super(MapGradient, self).__init__()

     def _process(self, grads):
         ret = []
......
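The change above relies on a TF 1.x name-scope detail: `with tf.name_scope(name) as scope` captures the full scope string ending in `/`, and passing that captured string back to `tf.name_scope` re-enters the same scope instead of creating a uniquified `Name_1`. A minimal sketch of the behavior (names are illustrative):

```python
import tensorflow as tf

with tf.name_scope('MyProc') as scope:  # scope == 'MyProc/'
    a = tf.constant(1, name='a')        # op name: 'MyProc/a'

with tf.name_scope('MyProc'):           # a fresh string gets uniquified
    b = tf.constant(2, name='b')        # op name: 'MyProc_1/b'

with tf.name_scope(scope):              # string ending in '/' re-enters the scope
    c = tf.constant(3, name='c')        # op name: 'MyProc/c', same scope as `a`
```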
@@ -330,6 +330,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
             # Ngpu x 2
             v = grad_and_vars[0][1]
             with tf.device(v.device):
+                # will call apply_gradients (therefore gradproc) multiple times
                 train_ops.append(opt.apply_gradients(
                     grad_and_vars, name='apply_grad_{}'.format(i)))
         self.train_op = tf.group(*train_ops, name='train_op')
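For context on why `process()` runs more than once here: the async trainer calls `opt.apply_gradients` once per GPU, and each call runs the optimizer's gradient processors, so without the cached scope every GPU would get its own `GlobalNormClip_i/` scope. A hypothetical sketch with the processor invoked explicitly (the two dummy losses, the direct `proc.process` call, and the import path are illustrative simplifications, not the trainer's actual code path):

```python
import tensorflow as tf
from tensorpack.tfutils.gradproc import GlobalNormClip  # assumed import path

x = tf.get_variable('x', shape=[3])
loss0 = tf.reduce_sum(x ** 2)  # stand-in for tower-0 loss
loss1 = tf.reduce_sum(x ** 2)  # stand-in for tower-1 loss

proc = GlobalNormClip(5.0)
opt = tf.train.GradientDescentOptimizer(0.1)
train_ops = []
for i, loss in enumerate([loss0, loss1]):
    grad_and_vars = opt.compute_gradients(loss)
    # Every call after the first reuses the cached 'GlobalNormClip/' scope
    # instead of creating 'GlobalNormClip_1/', 'GlobalNormClip_2/', ...
    processed = proc.process(grad_and_vars)
    train_ops.append(opt.apply_gradients(
        processed, name='apply_grad_{}'.format(i)))
train_op = tf.group(*train_ops, name='train_op')
```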