Commit ee227da4 authored by Yuxin Wu

refactor gradproc. fix performance problem of checkgradient

parent 37e664f8
@@ -72,7 +72,7 @@ class ModelDesc(object):
     def get_gradient_processor(self):
         """ Return a list of GradientProcessor. They will be executed in order"""
         return [#SummaryGradient(),
-                #CheckGradient()
+                CheckGradient()
                 ]
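With this default, CheckGradient now runs on every model unless a subclass overrides the hook. A minimal sketch of such an override (the MyModel class and the import paths are illustrative, inferred from the relative imports later in this diff):

    from tensorpack import ModelDesc
    from tensorpack.tfutils.gradproc import SummaryGradient, CheckGradient

    class MyModel(ModelDesc):
        # other required ModelDesc methods omitted
        def get_gradient_processor(self):
            # processors are applied in list order: summarize, then check
            return [SummaryGradient(), CheckGradient()]

The rest of the diff refactors the gradproc module itself: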
@@ -7,6 +7,7 @@
 import tensorflow as tf
 from abc import ABCMeta, abstractmethod
 import re
 from ..utils import logger
+from .symbolic_functions import rms
 from .summary import add_moving_summary

 __all__ = ['GradientProcessor', 'SummaryGradient', 'CheckGradient',
@@ -29,40 +30,64 @@ class GradientProcessor(object):
     def _process(self, grads):
         pass
+class MapGradient(GradientProcessor):
+    """
+    Apply a function on all gradients if the name matches regex.
+    Keep the other gradients unchanged.
+    """
+    def __init__(self, func, regex='.*'):
+        """
+        :param func: takes a (grad, var) pair and returns a grad. If it
+            returns None, the gradient is discarded.
+        :param regex: used to match variables. Defaults to matching all variables.
+        """
+        self.func = func
+        if not regex.endswith('$'):
+            regex = regex + '$'
+        self.regex = regex
+
+    def _process(self, grads):
+        ret = []
+        for grad, var in grads:
+            if re.match(self.regex, var.op.name):
+                grad = self.func(grad, var)
+                if grad is not None:
+                    ret.append((grad, var))
+            else:
+                ret.append((grad, var))
+        return ret
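MapGradient is the new workhorse that the processors below are rebuilt on, and it can also be used directly. A hedged sketch of standalone use, clipping gradients of variables whose names start with fc (regex, clip range, and import path are illustrative):

    import tensorflow as tf
    from tensorpack.tfutils.gradproc import MapGradient  # path assumed

    # func receives (grad, var) and returns a new grad; returning None drops the pair
    clip_fc = MapGradient(
        lambda grad, var: tf.clip_by_value(grad, -0.1, 0.1),
        regex='fc.*')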
 _summaried_gradient = set()
-class SummaryGradient(GradientProcessor):
+class SummaryGradient(MapGradient):
     """
     Summarize the histogram and RMS for each gradient variable
     """
-    def _process(self, grads):
-        for grad, var in grads:
-            name = var.op.name
-            if name in _summaried_gradient:
-                continue
-            _summaried_gradient.add(name)
-            tf.histogram_summary(name + '/grad', grad)
-            add_moving_summary(tf.sqrt(
-                tf.reduce_mean(tf.square(grad)),
-                name=name + '/RMS'))
-        return grads
+    def __init__(self):
+        super(SummaryGradient, self).__init__(self._mapper)
+
+    def _mapper(self, grad, var):
+        name = var.op.name
+        if name not in _summaried_gradient:
+            _summaried_gradient.add(name)
+            tf.histogram_summary(name + '/grad', grad)
+            add_moving_summary(rms(grad, name=name + '/rms'))
+        return grad
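The rms helper imported at the top of the module replaces the inline expression on the removed lines, so it is presumably equivalent to this sketch:

    def rms(x, name=None):
        # root-mean-square: the tf.sqrt(tf.reduce_mean(tf.square(...)))
        # expression that the old code inlined
        return tf.sqrt(tf.reduce_mean(tf.square(x)), name=name)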
-class CheckGradient(GradientProcessor):
+class CheckGradient(MapGradient):
     """
     Check for numeric issues.
     """
-    def _process(self, grads):
-        ret = []
-        for grad, var in grads:
-            op = tf.Assert(tf.reduce_all(tf.is_finite(var)),
-                           [var], summarize=100)
-            with tf.control_dependencies([op]):
-                grad = tf.identity(grad)
-            ret.append((grad, var))
-        return ret
+    def __init__(self):
+        super(CheckGradient, self).__init__(self._mapper)
+
+    def _mapper(self, grad, var):
+        # the Assert + control-dependency version is very slow:
+        #op = tf.Assert(tf.reduce_all(tf.is_finite(var)), [var], summarize=100)
+        grad = tf.check_numerics(grad, 'CheckGradient')
+        return grad
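Two things changed here: the old code asserted finiteness of var rather than of the gradient, and it built a separate tf.Assert op plus a control dependency per gradient, which is what made it slow. tf.check_numerics folds the test into the gradient tensor itself and fails the step if a NaN or Inf appears. A small illustration of its behavior (TF 0.x-era API):

    import tensorflow as tf

    x = tf.constant([1.0, float('nan')])
    checked = tf.check_numerics(x, 'CheckGradient')
    with tf.Session() as sess:
        sess.run(checked)  # raises InvalidArgumentError because of the NaN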
-class ScaleGradient(GradientProcessor):
+class ScaleGradient(MapGradient):
     """
     Scale gradients by a multiplier
     """
@@ -71,10 +96,9 @@ class ScaleGradient(GradientProcessor):
         :param multipliers: list of (regex, float)
         """
         self.multipliers = multipliers
+        super(ScaleGradient, self).__init__(self._mapper)

-    def _process(self, grads):
-        ret = []
-        for grad, var in grads:
-            varname = var.op.name
-            for regex, val in self.multipliers:
-                # always match against the whole name
+    def _mapper(self, grad, var):
+        varname = var.op.name
+        for regex, val in self.multipliers:
+            # always match against the whole name
@@ -84,31 +108,7 @@ class ScaleGradient(GradientProcessor):
             if re.match(regex, varname):
                 logger.info("Apply lr multiplier {} for {}".format(val, varname))
                 if val != 0:    # skip zero to speed up
-                    ret.append((grad * val, var))
-                break
-        else:
-            ret.append((grad, var))
-        return ret
-
-class MapGradient(GradientProcessor):
-    """
-    Apply a function on all gradient if the name matches regex.
-    """
-    def __init__(self, func, regex='.*'):
-        """
-        :param func: takes a tensor and returns a tensor
-        :param regex: used to match variables. default to match all variables.
-        """
-        self.func = func
-        if not regex.endswith('$'):
-            regex = regex + '$'
-        self.regex = regex
-
-    def _process(self, grads):
-        ret = []
-        for grad, var in grads:
-            if re.match(self.regex, var.op.name):
-                ret.append((self.func(grad), var))
-            else:
-                ret.append((grad, var))
-        return ret
+                    return grad * val
+                else:
+                    return None
+        return grad
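One behavioral note: with a zero multiplier the mapper returns None, so MapGradient._process drops the (grad, var) pair entirely instead of emitting a zero gradient, and only the first matching regex applies. A usage sketch (import path, regexes, and values are illustrative):

    from tensorpack.tfutils.gradproc import ScaleGradient  # path assumed

    proc = ScaleGradient([
        ('conv0.*/W', 0.1),  # 10x smaller effective learning rate
        ('fc.*/b', 0.0),     # zero multiplier: pair dropped, variable frozen
    ])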