Commit 3f238a01 authored by Yuxin Wu
parent ebdcf469
@@ -172,7 +172,7 @@ class SendStat(Triggerable):
             stats = [stats]
         self.stats = stats
 
-    def _trigger_epoch(self):
+    def _trigger(self):
         holder = self.trainer.stat_holder
         v = {k: holder.get_stat_now(k) for k in self.stats}
         cmd = self.command.format(**v)
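The hunk above only renames the hook that SendStat overrides, from `_trigger_epoch` to the more generic `_trigger`. The sketch below illustrates the hook pattern in isolation; the simplified Triggerable base class and the hard-coded stat values are stand-ins for illustration, not tensorpack's actual implementation.

    class Triggerable(object):
        # Simplified stand-in: the framework decides when to call trigger(),
        # and subclasses put their work in _trigger().
        def trigger(self):
            self._trigger()

        def _trigger(self):
            pass


    class SendStat(Triggerable):
        # Sketch of the callback after the rename: the body is unchanged,
        # only the overridden hook is now _trigger().
        def __init__(self, command, stats):
            if not isinstance(stats, list):
                stats = [stats]
            self.command = command
            self.stats = stats

        def _trigger(self):
            # The real callback reads these values from self.trainer.stat_holder;
            # they are hard-coded here so the sketch runs on its own.
            v = {k: 0.0 for k in self.stats}
            print(self.command.format(**v))


    SendStat('echo loss={loss}', 'loss').trigger()   # prints: echo loss=0.0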
@@ -12,7 +12,7 @@ from ..utils import logger
 from .symbolic_functions import rms
 from .summary import add_moving_summary
 
-__all__ = ['GradientProcessor', 'GlobalNormClip', 'MapGradient', 'SummaryGradient', 'CheckGradient',
+__all__ = ['GradientProcessor', 'FilterNoneGrad', 'GlobalNormClip', 'MapGradient', 'SummaryGradient', 'CheckGradient',
            'ScaleGradient', 'apply_grad_processors']
@@ -24,12 +24,8 @@ def apply_grad_processors(grads, gradprocs):
     Returns:
         list: list of (grad, var) went through the processors.
     """
-    g = []
-    for grad, var in grads:
-        if grad is None:
-            logger.warn("No Gradient w.r.t {}".format(var.op.name))
-        else:
-            g.append((grad, var))
+    gradprocs.insert(0, FilterNoneGrad())
+    g = grads
     for proc in gradprocs:
         g = proc.process(g)
     return g
@@ -58,6 +54,21 @@ class GradientProcessor(object):
         pass
 
 
+class FilterNoneGrad(GradientProcessor):
+    """
+    Skip the update and print a warning (instead of crashing),
+    when the gradient of certain variable is None.
+    """
+    def _process(self, grads):
+        g = []
+        for grad, var in grads:
+            if grad is None:
+                logger.warn("No Gradient w.r.t {}".format(var.op.name))
+            else:
+                g.append((grad, var))
+        return g
+
+
 class GlobalNormClip(GradientProcessor):
     """ Clip by global norm.
         The global norm is the sum of norm for **all** gradients.
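With this change, apply_grad_processors no longer filters None gradients inline: it prepends a FilterNoneGrad processor to the user-supplied list, so every later processor only sees valid (grad, var) pairs. The standalone sketch below shows that behavior; the _fake_var helper and the dummy gradient values are hypothetical stand-ins for TensorFlow objects, while the processor logic mirrors the diff.

    from types import SimpleNamespace


    def _fake_var(name):
        # Hypothetical stand-in for a TF variable; only .op.name is used below.
        return SimpleNamespace(op=SimpleNamespace(name=name))


    class GradientProcessor(object):
        # Minimal base class: subclasses implement _process().
        def process(self, grads):
            return self._process(grads)


    class FilterNoneGrad(GradientProcessor):
        # Drop (grad, var) pairs whose gradient is None, warning instead of crashing.
        def _process(self, grads):
            g = []
            for grad, var in grads:
                if grad is None:
                    # logger.warn in the real code
                    print("No Gradient w.r.t {}".format(var.op.name))
                else:
                    g.append((grad, var))
            return g


    def apply_grad_processors(grads, gradprocs):
        # Same control flow as the patched function: FilterNoneGrad always runs first.
        gradprocs.insert(0, FilterNoneGrad())
        g = grads
        for proc in gradprocs:
            g = proc.process(g)
        return g


    grads = [(0.5, _fake_var('conv1/W')), (None, _fake_var('conv1/b'))]
    print(apply_grad_processors(grads, []))   # the None gradient has been dropped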