Commit ee227da4 authored by Yuxin Wu

refactor gradproc. fix performance problem of checkgradient

parent 37e664f8
@@ -72,7 +72,7 @@ class ModelDesc(object):
def get_gradient_processor(self):
""" Return a list of GradientProcessor. They will be executed in order"""
return [#SummaryGradient(),
#CheckGradient()
CheckGradient()
]
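For context, a user-defined ModelDesc can override this hook to pick its own processors; a minimal sketch (MyModel is a hypothetical subclass, not part of this commit):

class MyModel(ModelDesc):
    # re-enable the gradient summaries in addition to the numeric check
    def get_gradient_processor(self):
        return [SummaryGradient(), CheckGradient()]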
@@ -7,6 +7,7 @@ import tensorflow as tf
from abc import ABCMeta, abstractmethod
import re
from ..utils import logger
from .symbolic_functions import rms
from .summary import add_moving_summary
__all__ = ['GradientProcessor', 'SummaryGradient', 'CheckGradient',
@@ -29,40 +30,64 @@ class GradientProcessor(object):
def _process(self, grads):
pass
class MapGradient(GradientProcessor):
"""
Apply a function on all gradients whose variable names match the regex.
Keep the other gradients unchanged.
"""
def __init__(self, func, regex='.*'):
"""
:param func: takes a (grad, var) pair and returns a grad. If it returns None, the
gradient is discarded.
:param regex: used to match variable names. Defaults to matching all variables.
"""
self.func = func
if not regex.endswith('$'):
regex = regex + '$'
self.regex = regex
def _process(self, grads):
ret = []
for grad, var in grads:
if re.match(self.regex, var.op.name):
grad = self.func(grad, var)
if grad is not None:
ret.append((grad, var))
else:
ret.append((grad, var))
return ret
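A minimal usage sketch for the refactored MapGradient: the mapper now receives the (grad, var) pair and may return None to drop that gradient. The name pattern and clipping values below are made up for illustration:

# clip gradients of variables whose names match a hypothetical pattern
clip_conv = MapGradient(
    lambda grad, var: tf.clip_by_value(grad, -0.1, 0.1),
    regex='conv.*/W')

# returning None from the mapper discards the matched gradient entirely
drop_bias = MapGradient(lambda grad, var: None, regex='.*/b')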
_summaried_gradient = set()
class SummaryGradient(GradientProcessor):
class SummaryGradient(MapGradient):
"""
Summarize the history and RMS of each gradient variable
"""
def _process(self, grads):
for grad, var in grads:
name = var.op.name
if name in _summaried_gradient:
continue
def __init__(self):
super(SummaryGradient, self).__init__(self._mapper)
def _mapper(self, grad, var):
name = var.op.name
if name not in _summaried_gradient:
_summaried_gradient.add(name)
tf.histogram_summary(name + '/grad', grad)
add_moving_summary(tf.sqrt(
tf.reduce_mean(tf.square(grad)),
name=name + '/RMS'))
return grads
add_moving_summary(rms(grad, name=name + '/rms'))
return grad
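The rms helper imported above replaces the inline summary math removed here; it is presumably equivalent to something like the following sketch (not the actual symbolic_functions code):

def rms(x, name=None):
    # root-mean-square, matching the removed tf.sqrt(tf.reduce_mean(tf.square(...)))
    return tf.sqrt(tf.reduce_mean(tf.square(x)), name=name)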
class CheckGradient(GradientProcessor):
class CheckGradient(MapGradient):
"""
Check for numeric issue
Check for numeric issue.
"""
def _process(self, grads):
ret = []
for grad, var in grads:
op = tf.Assert(tf.reduce_all(tf.is_finite(var)),
[var], summarize=100)
with tf.control_dependencies([op]):
grad = tf.identity(grad)
ret.append((grad, var))
return ret
def __init__(self):
super(CheckGradient, self).__init__(self._mapper)
class ScaleGradient(GradientProcessor):
def _mapper(self, grad, var):
# this is very slow...
#op = tf.Assert(tf.reduce_all(tf.is_finite(var)), [var], summarize=100)
grad = tf.check_numerics(grad, 'CheckGradient')
return grad
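The performance fix named in the commit message is this mapper: the per-variable tf.Assert over tf.reduce_all(tf.is_finite(var)), noted above as very slow, is replaced by tf.check_numerics on the gradient, which passes the tensor through unchanged and raises an error only if it contains NaN or Inf. A minimal sketch of the behaviour (some_grad is a stand-in tensor):

some_grad = tf.placeholder(tf.float32, shape=[None])   # stand-in for a gradient tensor
# g is numerically identical to some_grad, but evaluating it raises an
# InvalidArgumentError if some_grad contains NaN or Inf values
g = tf.check_numerics(some_grad, 'found a bad gradient')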
class ScaleGradient(MapGradient):
"""
Scale gradient by a multiplier
"""
@@ -71,44 +96,19 @@ class ScaleGradient(GradientProcessor):
:param multipliers: list of (regex, float)
"""
self.multipliers = multipliers
super(ScaleGradient, self).__init__(self._mapper)
def _process(self, grads):
ret = []
for grad, var in grads:
varname = var.op.name
for regex, val in self.multipliers:
# always match against the whole name
if not regex.endswith('$'):
regex = regex + '$'
if re.match(regex, varname):
logger.info("Apply lr multiplier {} for {}".format(val, varname))
if val != 0: # skip zero to speed up
ret.append((grad * val, var))
break
else:
ret.append((grad, var))
return ret
def _mapper(self, grad, var):
varname = var.op.name
for regex, val in self.multipliers:
# always match against the whole name
if not regex.endswith('$'):
regex = regex + '$'
class MapGradient(GradientProcessor):
"""
Apply a function on all gradient if the name matches regex.
"""
def __init__(self, func, regex='.*'):
"""
:param func: takes a tensor and returns a tensor
:param regex: used to match variables. default to match all variables.
"""
self.func = func
if not regex.endswith('$'):
regex = regex + '$'
self.regex = regex
def _process(self, grads):
ret = []
for grad, var in grads:
if re.match(self.regex, var.op.name):
ret.append((self.func(grad), var))
else:
ret.append((grad, var))
return ret
if re.match(regex, varname):
logger.info("Apply lr multiplier {} for {}".format(val, varname))
if val != 0: # skip zero to speed up
return grad * val
else:
return None
return grad
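A usage sketch for the refactored ScaleGradient, which is now just a MapGradient with a multiplier table; the layer-name patterns and values are hypothetical:

# 5x learning rate on fully-connected weights, freeze everything under conv0
# (a multiplier of 0 makes the mapper return None, so the gradient is dropped)
scale = ScaleGradient([('fc.*/W', 5.0), ('conv0.*', 0.0)])

# processors returned from get_gradient_processor() are applied in this order
procs = [scale, SummaryGradient(), CheckGradient()]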