Commit 6b85a1f1 authored by Yuxin Wu

global norm clip

parent 68cb6994
@@ -90,8 +90,7 @@ class Model(ModelDesc):
         summary.add_param_summary([('.*/W', ['histogram'])])   # monitor histogram of all W
 
     def get_gradient_processor(self):
-        return [MapGradient(lambda grad: tf.clip_by_global_norm(
-            [grad], param.grad_clip)[0][0])]
+        return [GlobalNormClip(5)]
 
 def get_config():
     logger.auto_set_dir()
...
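For context, a standalone sketch (not part of the commit; the gradient values are made up for illustration) of how the old per-tensor clipping above differs from the new global-norm clipping:

    import tensorflow as tf

    # Made-up gradients, for illustration only.
    grads = [tf.constant([3.0, 4.0]), tf.constant([0.0, 12.0])]

    # Old behavior: each gradient is clipped against the threshold on its own;
    # clip_by_global_norm on a one-element list degenerates to per-tensor clipping.
    clipped_each = [tf.clip_by_global_norm([g], 5.0)[0][0] for g in grads]

    # New behavior: one norm sqrt(3^2 + 4^2 + 0^2 + 12^2) = 13 is computed over
    # all gradients, and every tensor is rescaled by the same factor 5/13.
    clipped_all, global_norm = tf.clip_by_global_norm(grads, 5.0)

Per-tensor clipping changes the relative scale between gradients; clipping by the global norm rescales them uniformly and so preserves the overall update direction, which is the usual motivation for the switch.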
@@ -12,7 +12,8 @@ from .symbolic_functions import rms
 from .summary import add_moving_summary
 
 __all__ = ['GradientProcessor', 'SummaryGradient', 'CheckGradient',
-           'ScaleGradient', 'MapGradient', 'apply_grad_processors']
+           'ScaleGradient', 'MapGradient', 'apply_grad_processors',
+           'GlobalNormClip']
 
 def apply_grad_processors(grads, gradprocs):
     """
...
@@ -47,6 +48,20 @@ class GradientProcessor(object):
     def _process(self, grads):
         pass
 
+class GlobalNormClip(GradientProcessor):
+    def __init__(self, global_norm):
+        """ Clip by global norm.
+            Note that the global norm is computed over **all** gradients together.
+        """
+        self._norm = global_norm
+
+    def _process(self, grads):
+        g = [k[0] for k in grads]
+        v = [k[1] for k in grads]
+        g, _ = tf.clip_by_global_norm(g, self._norm, name='clip_by_global_norm')
+        return list(zip(g, v))
+
 class MapGradient(GradientProcessor):
     """
     Apply a function on all gradients whose name matches the regex.
...
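A possible usage sketch of the new processor; opt and cost are placeholder names, and only GlobalNormClip and apply_grad_processors come from this file:

    import tensorflow as tf

    opt = tf.train.GradientDescentOptimizer(0.01)
    grads = opt.compute_gradients(cost)   # cost: some scalar loss, assumed defined

    # Rescale all gradients together so that their global norm is at most 5.
    grads = apply_grad_processors(grads, [GlobalNormClip(5)])
    train_op = opt.apply_gradients(grads)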
@@ -98,6 +98,7 @@ class Trainer(object):
     def setup(self):
         self._setup()
         describe_model()
+        get_global_step_var()
         # some final operations that might modify the graph
         logger.info("Setup callbacks ...")
         self.config.callbacks.setup_graph(weakref.proxy(self))
...
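The added get_global_step_var() call ensures the global-step variable exists in the graph before callbacks are set up. A minimal sketch of what such a helper typically looks like (an assumption for illustration, not tensorpack's actual implementation):

    import tensorflow as tf

    def get_global_step_var():
        """Return the graph's global-step variable, creating it if absent."""
        with tf.variable_scope('', reuse=tf.AUTO_REUSE):
            return tf.get_variable('global_step', shape=[], dtype=tf.int64,
                                   initializer=tf.zeros_initializer(),
                                   trainable=False)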
@@ -126,7 +126,7 @@ class FeedlessTrainer(Trainer):
     Always return new tensors (for multi tower) if called multiple times.
     """
 
-class SingleCostFeedlessTrainer(Trainer):
+class SingleCostFeedlessTrainer(FeedlessTrainer):
     def _get_cost_and_grad(self):
         """ get the cost and gradient on a new tower"""
         actual_inputs = self._get_input_tensors_noreuse()
...