Commit 8df83a93 authored by Yuxin Wu

some change in gradproc, and fix #158

parent 10f55570
@@ -12,23 +12,9 @@ from ..utils import logger
 from .symbolic_functions import rms
 from .summary import add_moving_summary
-__all__ = ['GradientProcessor', 'FilterNoneGrad', 'GlobalNormClip', 'MapGradient', 'SummaryGradient', 'CheckGradient',
-           'ScaleGradient', 'apply_grad_processors']
-def apply_grad_processors(grads, gradprocs):
-    """
-    Args:
-        grads (list): list of (grad, var).
-        gradprocs (list[GradientProcessor]): gradient processors to apply.
-    Returns:
-        list: list of (grad, var) went through the processors.
-    """
-    gradprocs.insert(0, FilterNoneGrad())
-    g = grads
-    for proc in gradprocs:
-        g = proc.process(g)
-    return g
+__all__ = ['GradientProcessor',
+           'FilterNoneGrad', 'GlobalNormClip', 'MapGradient', 'SummaryGradient',
+           'CheckGradient', 'ScaleGradient']
 @six.add_metaclass(ABCMeta)
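With the standalone apply_grad_processors helper removed from gradproc (it used to prepend FilterNoneGrad and chain the processors itself), callers chain process() calls explicitly. A minimal sketch of the equivalent pattern, assuming TF 1.x graph mode; the variables and the GlobalNormClip(5) call (clip by global norm 5) are only illustrative:

import tensorflow as tf
from tensorpack.tfutils.gradproc import FilterNoneGrad, GlobalNormClip

W = tf.get_variable('W', shape=[3, 3])
b = tf.get_variable('b', shape=[3])   # unused by the cost, so its gradient is None
cost = tf.reduce_sum(tf.square(W))
grads = list(zip(tf.gradients(cost, [W, b]), [W, b]))

# Same effect as the removed helper: FilterNoneGrad first, then each processor.
for proc in [FilterNoneGrad(), GlobalNormClip(5)]:
    grads = proc.process(grads)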
@@ -118,13 +104,17 @@ class MapGradient(GradientProcessor):
     def _process(self, grads):
         ret = []
+        matched = False
         for grad, var in grads:
             if re.match(self.regex, var.op.name):
+                matched = True
                 grad = self.func(grad, var)
                 if grad is not None:
                     ret.append((grad, var))
             else:
                 ret.append((grad, var))
+        if not matched:
+            logger.warn("[MapGradient] No match was found for regex {}.".format(self.regex))
         return ret
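The new matched flag makes MapGradient warn when its regex matches no variable name. A minimal sketch of the behavior, assuming TF 1.x and a MapGradient(func, regex) constructor that stores the attributes used above; the variable scope and regexes are illustrative:

import tensorflow as tf
from tensorpack.tfutils.gradproc import MapGradient

with tf.variable_scope('conv0'):
    W = tf.get_variable('W', shape=[3, 3, 3, 16])
cost = tf.reduce_sum(tf.square(W))
grads = list(zip(tf.gradients(cost, [W]), [W]))

# Matches 'conv0/W': the gradient is rescaled, no warning.
grads = MapGradient(lambda grad, var: grad * 0.1, regex='conv0/W').process(grads)

# Matches nothing: now logs "[MapGradient] No match was found for regex ...".
grads = MapGradient(lambda grad, var: grad * 0.1, regex='fc0/W').process(grads)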
......
@@ -5,7 +5,6 @@
 import tensorflow as tf
 from contextlib import contextmanager
-from .gradproc import apply_grad_processors as apply_gradproc
+from .gradproc import FilterNoneGrad
 __all__ = ['apply_grad_processors', 'ProxyOptimizer',
@@ -48,13 +47,19 @@ def apply_grad_processors(opt, gradprocs):
     class _ApplyGradientProcessor(ProxyOptimizer):
         def __init__(self, opt, gradprocs):
-            self._gradprocs = gradprocs
+            self._gradprocs = [FilterNoneGrad()] + gradprocs
             super(_ApplyGradientProcessor, self).__init__(opt)
         def apply_gradients(self, grads_and_vars,
                             global_step=None, name=None):
-            g = apply_gradproc(grads_and_vars, self._gradprocs)
+            g = self._apply(grads_and_vars)
             return self._opt.apply_gradients(g, global_step, name)
+        def _apply(self, g):
+            for proc in self._gradprocs:
+                g = proc.process(g)
+            return g
     return _ApplyGradientProcessor(opt, gradprocs)
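Since FilterNoneGrad is now prepended inside the wrapper, the optimizer-level apply_grad_processors(opt, gradprocs) is self-contained. A minimal usage sketch, assuming TF 1.x; the Adam optimizer, cost and variable are illustrative:

import tensorflow as tf
from tensorpack.tfutils.optimizer import apply_grad_processors
from tensorpack.tfutils.gradproc import GlobalNormClip

W = tf.get_variable('W', shape=[3, 3])
cost = tf.reduce_sum(tf.square(W))

opt = apply_grad_processors(tf.train.AdamOptimizer(1e-3), [GlobalNormClip(5)])
grads_and_vars = opt.compute_gradients(cost)
# FilterNoneGrad and GlobalNormClip run here, before Adam applies the update.
train_op = opt.apply_gradients(grads_and_vars)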
......
@@ -13,7 +13,7 @@ from ..utils.naming import SUMMARY_BACKUP_KEYS
 from ..utils.concurrency import LoopThread
 from ..tfutils.tower import TowerContext
 from ..tfutils.collection import backup_collection, restore_collection
-from ..tfutils.gradproc import apply_grad_processors, ScaleGradient
+from ..tfutils.gradproc import FilterNoneGrad, ScaleGradient
 from .base import Trainer
 from .feedfree import SingleCostFeedfreeTrainer
@@ -190,11 +190,12 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
         super(AsyncMultiGPUTrainer, self)._setup()
         grad_list = MultiGPUTrainer._multi_tower_grads(
             self.config.tower, lambda: self._get_cost_and_grad()[1])
+        grad_list = FilterNoneGrad().process(grad_list)
         if self._scale_gradient and self.config.nr_tower > 1:
             # pretend to average the grads, in order to make async and
             # sync have consistent effective learning rate
             gradproc = ScaleGradient(('.*', 1.0 / self.config.nr_tower), log=False)
-            grad_list = apply_grad_processors(grad_list, [gradproc])
+            grad_list = gradproc.process(grad_list)
         # use grad from the first tower for iteration in main thread
         self.train_op = self.config.optimizer.apply_gradients(grad_list[0], name='min_op')
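The async trainer now calls the processors directly instead of the removed helper: FilterNoneGrad on the tower gradients, then ScaleGradient with factor 1/nr_tower so asynchronous updates keep the same effective learning rate as synchronous averaging. A minimal single-variable sketch of that chain, assuming TF 1.x; nr_tower is hard-coded for illustration:

import tensorflow as tf
from tensorpack.tfutils.gradproc import FilterNoneGrad, ScaleGradient

nr_tower = 2
W = tf.get_variable('W', shape=[3, 3])
cost = tf.reduce_sum(tf.square(W))
grads = list(zip(tf.gradients(cost, [W]), [W]))

grads = FilterNoneGrad().process(grads)
# Scale every gradient by 1/nr_tower without logging each scaled variable.
grads = ScaleGradient(('.*', 1.0 / nr_tower), log=False).process(grads)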
......