Commit ee227da4 authored by Yuxin Wu

refactor gradproc. fix performance problem of checkgradient

parent 37e664f8
@@ -72,7 +72,7 @@ class ModelDesc(object):
     def get_gradient_processor(self):
         """ Return a list of GradientProcessor. They will be executed in order"""
         return [#SummaryGradient(),
-                #CheckGradient()
+                CheckGradient()
                 ]
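This first hunk simply turns on CheckGradient() in the default processor list. For reference, a minimal sketch of how a model could customize this hook instead of relying on the default; the subclass name and import paths are assumptions, not part of the commit:

    from tensorpack.models import ModelDesc                    # assumed import path
    from tensorpack.tfutils.gradproc import SummaryGradient, CheckGradient

    class MyModel(ModelDesc):   # hypothetical model
        def get_gradient_processor(self):
            # processors run in list order: write summaries, then check numerics
            return [SummaryGradient(), CheckGradient()]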
@@ -7,6 +7,7 @@ import tensorflow as tf
 from abc import ABCMeta, abstractmethod
 import re
 from ..utils import logger
+from .symbolic_functions import rms
 from .summary import add_moving_summary
 
 __all__ = ['GradientProcessor', 'SummaryGradient', 'CheckGradient',
@@ -29,40 +30,64 @@ class GradientProcessor(object):
     def _process(self, grads):
         pass
 
+class MapGradient(GradientProcessor):
+    """
+    Apply a function on all gradient if the name matches regex.
+    Keep the other gradients unchanged.
+    """
+    def __init__(self, func, regex='.*'):
+        """
+        :param func: takes a (grad, var) pair and returns a grad. If return None, the
+            gradient is discarded.
+        :param regex: used to match variables. default to match all variables.
+        """
+        self.func = func
+        if not regex.endswith('$'):
+            regex = regex + '$'
+        self.regex = regex
+
+    def _process(self, grads):
+        ret = []
+        for grad, var in grads:
+            if re.match(self.regex, var.op.name):
+                grad = self.func(grad, var)
+                if grad is not None:
+                    ret.append((grad, var))
+            else:
+                ret.append((grad, var))
+        return ret
+
 _summaried_gradient = set()
 
-class SummaryGradient(GradientProcessor):
+class SummaryGradient(MapGradient):
     """
     Summary history and RMS for each gradient variable
     """
-    def _process(self, grads):
-        for grad, var in grads:
-            name = var.op.name
-            if name in _summaried_gradient:
-                continue
-            _summaried_gradient.add(name)
-            tf.histogram_summary(name + '/grad', grad)
-            add_moving_summary(tf.sqrt(
-                tf.reduce_mean(tf.square(grad)),
-                name=name + '/RMS'))
-        return grads
+    def __init__(self):
+        super(SummaryGradient, self).__init__(self._mapper)
+
+    def _mapper(self, grad, var):
+        name = var.op.name
+        if name not in _summaried_gradient:
+            _summaried_gradient.add(name)
+            tf.histogram_summary(name + '/grad', grad)
+            add_moving_summary(rms(grad, name=name + '/rms'))
+        return grad
 
-class CheckGradient(GradientProcessor):
+class CheckGradient(MapGradient):
     """
-    Check for numeric issue
+    Check for numeric issue.
     """
-    def _process(self, grads):
-        ret = []
-        for grad, var in grads:
-            op = tf.Assert(tf.reduce_all(tf.is_finite(var)),
-                           [var], summarize=100)
-            with tf.control_dependencies([op]):
-                grad = tf.identity(grad)
-            ret.append((grad, var))
-        return ret
+    def __init__(self):
+        super(CheckGradient, self).__init__(self._mapper)
+
+    def _mapper(self, grad, var):
+        # this is very slow...
+        # op = tf.Assert(tf.reduce_all(tf.is_finite(var)), [var], summarize=100)
+        grad = tf.check_numerics(grad, 'CheckGradient')
+        return grad
 
-class ScaleGradient(GradientProcessor):
+class ScaleGradient(MapGradient):
     """
     Scale gradient by a multiplier
     """
@@ -71,44 +96,19 @@ class ScaleGradient(GradientProcessor):
         :param multipliers: list of (regex, float)
         """
         self.multipliers = multipliers
+        super(ScaleGradient, self).__init__(self._mapper)
 
-    def _process(self, grads):
-        ret = []
-        for grad, var in grads:
-            varname = var.op.name
-            for regex, val in self.multipliers:
-                # always match against the whole name
-                if not regex.endswith('$'):
-                    regex = regex + '$'
-                if re.match(regex, varname):
-                    logger.info("Apply lr multiplier {} for {}".format(val, varname))
-                    if val != 0:    # skip zero to speed up
-                        ret.append((grad * val, var))
-                    break
-            else:
-                ret.append((grad, var))
-        return ret
-
-class MapGradient(GradientProcessor):
-    """
-    Apply a function on all gradient if the name matches regex.
-    """
-    def __init__(self, func, regex='.*'):
-        """
-        :param func: takes a tensor and returns a tensor
-        :param regex: used to match variables. default to match all variables.
-        """
-        self.func = func
-        if not regex.endswith('$'):
-            regex = regex + '$'
-        self.regex = regex
-
-    def _process(self, grads):
-        ret = []
-        for grad, var in grads:
-            if re.match(self.regex, var.op.name):
-                ret.append((self.func(grad), var))
-            else:
-                ret.append((grad, var))
-        return ret
+    def _mapper(self, grad, var):
+        varname = var.op.name
+        for regex, val in self.multipliers:
+            # always match against the whole name
+            if not regex.endswith('$'):
+                regex = regex + '$'
+            if re.match(regex, varname):
+                logger.info("Apply lr multiplier {} for {}".format(val, varname))
+                if val != 0:    # skip zero to speed up
+                    return grad * val
+                else:
+                    return None
+        return grad
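End-to-end, ScaleGradient now reads as a one-pair mapper. A short, self-contained sketch under graph-mode TensorFlow; the import path, variable name, and patterns are hypothetical:

    import tensorflow as tf
    from tensorpack.tfutils.gradproc import ScaleGradient   # assumed path

    w = tf.Variable(tf.zeros([3]), name='fc_final_W')   # hypothetical variable
    cost = tf.reduce_sum(w * w)
    opt = tf.train.GradientDescentOptimizer(0.1)

    proc = ScaleGradient([
        ('fc_final.*', 10.0),   # 10x learning rate on matching variables
        ('frozen.*', 0.0),      # 0 -> _mapper returns None -> gradient dropped
    ])
    grads = proc._process(opt.compute_gradients(cost))
    train_op = opt.apply_gradients(grads)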