Commit 6b85a1f1 authored by Yuxin Wu

global norm clip

parent 68cb6994
@@ -90,8 +90,7 @@ class Model(ModelDesc):
         summary.add_param_summary([('.*/W', ['histogram'])])   # monitor histogram of all W
 
     def get_gradient_processor(self):
-        return [MapGradient(lambda grad: tf.clip_by_global_norm(
-            [grad], param.grad_clip)[0][0])]
+        return [GlobalNormClip(5)]
 
 def get_config():
     logger.auto_set_dir()
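For context: the replaced MapGradient wrapped each gradient in a one-element list, so tf.clip_by_global_norm clipped every gradient against the threshold on its own. GlobalNormClip (added below) passes all gradients in one call, so they are scaled jointly by a single shared norm. A minimal sketch of the difference, assuming TF 1.x-style tensors and a threshold of 5:

    import tensorflow as tf

    grads = [tf.constant([3.0, 4.0]), tf.constant([6.0, 8.0])]  # norms 5 and 10

    # Old behavior: each tensor clipped independently, so both end up with norm <= 5.
    per_tensor = [tf.clip_by_global_norm([g], 5)[0][0] for g in grads]

    # New behavior: one shared norm sqrt(5**2 + 10**2) ~= 11.18, so both tensors
    # are scaled by the same factor 5 / 11.18.
    joint, global_norm = tf.clip_by_global_norm(grads, 5)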
@@ -12,7 +12,8 @@ from .symbolic_functions import rms
 from .summary import add_moving_summary
 
 __all__ = ['GradientProcessor', 'SummaryGradient', 'CheckGradient',
-           'ScaleGradient', 'MapGradient', 'apply_grad_processors']
+           'ScaleGradient', 'MapGradient', 'apply_grad_processors',
+           'GlobalNormClip']
 
 def apply_grad_processors(grads, gradprocs):
     """
@@ -47,6 +48,20 @@ class GradientProcessor(object):
     def _process(self, grads):
         pass
 
+class GlobalNormClip(GradientProcessor):
+    def __init__(self, global_norm):
+        """ Clip by global norm.
+            Note that the global norm is computed over **all** gradients together.
+        """
+        self._norm = global_norm
+
+    def _process(self, grads):
+        g = [k[0] for k in grads]
+        v = [k[1] for k in grads]
+        g, _ = tf.clip_by_global_norm(g, self._norm, name='clip_by_global_norm')
+        return list(zip(g, v))
+
 class MapGradient(GradientProcessor):
     """
     Apply a function on all gradients if the name matches the regex.
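A hedged usage sketch of the new processor through apply_grad_processors, which is exported from the same module; opt and cost are placeholders for whatever TF 1.x optimizer and loss the model defines:

    grads = opt.compute_gradients(cost)                       # list of (grad, var) pairs
    grads = apply_grad_processors(grads, [GlobalNormClip(5)])
    train_op = opt.apply_gradients(grads)

Note that _process receives and returns (gradient, variable) pairs: the class strips the variables out, clips the gradient list as a whole, and zips the variables back in.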
@@ -98,6 +98,7 @@ class Trainer(object):
     def setup(self):
         self._setup()
         describe_model()
+        get_global_step_var()
         # some final operations that might modify the graph
         logger.info("Setup callbacks ...")
         self.config.callbacks.setup_graph(weakref.proxy(self))
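Calling get_global_step_var() here guarantees the global-step variable exists in the graph before the callbacks build their ops in setup_graph below. As a rough, assumed sketch (not tensorpack's actual code), such a helper boils down to a non-trainable scalar counter:

    import tensorflow as tf

    def get_global_step_var():
        # Hypothetical sketch: a non-trainable counter that the optimizer
        # increments; real code would also guard against creating it twice.
        return tf.get_variable('global_step', trainable=False,
                               initializer=tf.constant(0, dtype=tf.int64))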
@@ -126,7 +126,7 @@ class FeedlessTrainer(Trainer):
         Always return new tensors (for multi tower) if called multiple times.
         """
 
-class SingleCostFeedlessTrainer(Trainer):
+class SingleCostFeedlessTrainer(FeedlessTrainer):
     def _get_cost_and_grad(self):
         """ get the cost and gradient on a new tower """
         actual_inputs = self._get_input_tensors_noreuse()
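This one-line base-class fix matters because _get_input_tensors_noreuse is defined on FeedlessTrainer; with Trainer as the direct base, the lookup in _get_cost_and_grad fails. A self-contained repro of the old bug, with stand-in method bodies:

    class Trainer(object):
        pass

    class FeedlessTrainer(Trainer):
        def _get_input_tensors_noreuse(self):
            return ['input_tensor']            # stand-in for fresh input tensors

    class SingleCostFeedlessTrainer(Trainer):  # old, broken base class
        def _get_cost_and_grad(self):
            return self._get_input_tensors_noreuse()

    try:
        SingleCostFeedlessTrainer()._get_cost_and_grad()
    except AttributeError as e:
        print(e)  # the method is missing from this class's MRO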