Commit 3f238a01 authored by Yuxin Wu
parent ebdcf469
@@ -172,7 +172,7 @@ class SendStat(Triggerable):
             stats = [stats]
         self.stats = stats
 
-    def _trigger_epoch(self):
+    def _trigger(self):
         holder = self.trainer.stat_holder
         v = {k: holder.get_stat_now(k) for k in self.stats}
         cmd = self.command.format(**v)
@@ -12,7 +12,7 @@ from ..utils import logger
 from .symbolic_functions import rms
 from .summary import add_moving_summary
 
-__all__ = ['GradientProcessor', 'GlobalNormClip', 'MapGradient', 'SummaryGradient', 'CheckGradient',
+__all__ = ['GradientProcessor', 'FilterNoneGrad', 'GlobalNormClip', 'MapGradient', 'SummaryGradient', 'CheckGradient',
            'ScaleGradient', 'apply_grad_processors']
 
 
@@ -24,12 +24,8 @@ def apply_grad_processors(grads, gradprocs):
     Returns:
         list: list of (grad, var) went through the processors.
     """
-    g = []
-    for grad, var in grads:
-        if grad is None:
-            logger.warn("No Gradient w.r.t {}".format(var.op.name))
-        else:
-            g.append((grad, var))
+    gradprocs.insert(0, FilterNoneGrad())
+    g = grads
     for proc in gradprocs:
         g = proc.process(g)
     return g
@@ -58,6 +54,21 @@ class GradientProcessor(object):
         pass
 
 
+class FilterNoneGrad(GradientProcessor):
+    """
+    Skip the update and print a warning (instead of crashing),
+    when the gradient of certain variable is None.
+    """
+    def _process(self, grads):
+        g = []
+        for grad, var in grads:
+            if grad is None:
+                logger.warn("No Gradient w.r.t {}".format(var.op.name))
+            else:
+                g.append((grad, var))
+        return g
+
+
 class GlobalNormClip(GradientProcessor):
     """ Clip by global norm.
         The global norm is the sum of norm for **all** gradients.
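In short, the commit moves the None-gradient filtering out of apply_grad_processors and into a reusable FilterNoneGrad processor that is prepended to the processor list, so every later processor only ever sees valid (grad, var) pairs. Below is a minimal standalone sketch of that behaviour, under the assumption of plain Python with no TensorFlow or tensorpack imports; _MockOp and _MockVar are hypothetical stand-ins for a tf.Variable exposing var.op.name and are not part of the library.

    # Standalone sketch of the processor chain after this commit.
    import logging

    logger = logging.getLogger("gradproc_sketch")
    logging.basicConfig(level=logging.WARNING)


    class GradientProcessor(object):
        def process(self, grads):
            return self._process(grads)


    class FilterNoneGrad(GradientProcessor):
        """Drop (None, var) pairs with a warning instead of crashing later."""
        def _process(self, grads):
            g = []
            for grad, var in grads:
                if grad is None:
                    logger.warning("No Gradient w.r.t %s", var.op.name)
                else:
                    g.append((grad, var))
            return g


    def apply_grad_processors(grads, gradprocs):
        # FilterNoneGrad is prepended, so later processors never see None grads.
        gradprocs.insert(0, FilterNoneGrad())
        g = grads
        for proc in gradprocs:
            g = proc.process(g)
        return g


    class _MockOp(object):
        def __init__(self, name):
            self.name = name


    class _MockVar(object):
        def __init__(self, name):
            self.op = _MockOp(name)


    if __name__ == '__main__':
        grads = [(0.5, _MockVar('conv0/W')), (None, _MockVar('fc0/b'))]
        processed = apply_grad_processors(grads, [])
        # Only the conv0/W pair survives; fc0/b triggers a warning.
        print([(g, v.op.name) for g, v in processed])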