Commit 0d894877 authored by Yuxin Wu

gradproc

parent 80622ae7
......@@ -14,6 +14,7 @@ from tensorpack.train import TrainConfig, SimpleTrainer
from tensorpack.models import *
from tensorpack.utils import *
from tensorpack.utils.symbolic_functions import *
from tensorpack.utils.gradproc import *
from tensorpack.utils.summary import *
from tensorpack.callbacks import *
from tensorpack.dataflow import *
......
......@@ -6,7 +6,7 @@
import tensorflow as tf
import itertools
from tqdm import tqdm
from abc import ABCMeta
from abc import ABCMeta, abstractmethod
from ..utils import *
from ..utils.stat import *
......
......@@ -7,6 +7,8 @@ from abc import ABCMeta, abstractmethod
import tensorflow as tf
from collections import namedtuple
from ..utils.gradproc import *
__all__ = ['ModelDesc', 'InputVar']
InputVar = namedtuple('InputVar', ['type', 'shape', 'name'])
......@@ -68,8 +70,6 @@ class ModelDesc(object):
the cost to minimize. scalar variable
"""
def get_lr_multiplier(self):
"""
Return a list of (variable_regex: multiplier)
"""
return []
def get_gradient_processor(self):
""" Return a list of GradientProcessor. They will be executed in order"""
return [SummaryGradient(), CheckGradient()]
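
As an illustration (not part of this commit), a user model could override get_gradient_processor to customize this chain. MyModel, the 'fc.*/W' regex, and the 0.1 multiplier below are hypothetical; the wildcard imports mirror the example script at the top of this diff:

    from tensorpack.models import *            # assumed to re-export ModelDesc
    from tensorpack.utils.gradproc import *    # SummaryGradient, CheckGradient, ScaleGradient

    class MyModel(ModelDesc):
        # ... the model's other methods (inputs, cost, ...) omitted ...

        def get_gradient_processor(self):
            # Processors are applied in the listed order: first scale the
            # gradients of variables matching 'fc.*/W' by 0.1, then add
            # gradient summaries, then check for None / non-finite values.
            return [ScaleGradient([('fc.*/W', 0.1)]),
                    SummaryGradient(),
                    CheckGradient()]
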
......@@ -111,3 +111,8 @@ class Trainer(object):
tf.train.start_queue_runners(
sess=self.sess, coord=self.coord, daemon=True, start=True)
def process_grads(self, grads):
procs = self.config.model.get_gradient_processor()
for proc in procs:
grads = proc.process(grads)
return grads
......@@ -15,17 +15,6 @@ from ..utils.summary import summary_moving_average
__all__ = ['SimpleTrainer', 'QueueInputTrainer', 'start_train']
def summary_grads(grads):
for grad, var in grads:
if grad:
# TODO also summary RMS and print
tf.histogram_summary(var.op.name + '/gradients', grad)
def check_grads(grads):
for grad, var in grads:
assert grad is not None, "Grad is None for variable {}".format(var.name)
tf.Assert(tf.reduce_all(tf.is_finite(var)), [var])
def scale_grads(grads, multiplier):
ret = []
for grad, var in grads:
......@@ -54,9 +43,7 @@ class SimpleTrainer(Trainer):
avg_maintain_op = summary_moving_average(cost_var)
grads = self.config.optimizer.compute_gradients(cost_var)
check_grads(grads)
grads = scale_grads(grads, model.get_lr_multiplier())
summary_grads(grads)
grads = self.process_grads(grads)
self.train_op = tf.group(
self.config.optimizer.apply_gradients(grads, get_global_step_var()),
......@@ -133,9 +120,7 @@ class QueueInputTrainer(Trainer):
grads = self.config.optimizer.compute_gradients(cost_var)
avg_maintain_op = summary_moving_average(cost_var) # TODO(multigpu) average the cost from each device?
check_grads(grads)
grads = scale_grads(grads, model.get_lr_multiplier())
summary_grads(grads)
grads = self.process_grads(grads)
self.train_op = tf.group(
self.config.optimizer.apply_gradients(grads, get_global_step_var()),
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: gradproc.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
from abc import ABCMeta, abstractmethod
import re
from . import logger
__all__ = ['GradientProcessor', 'SummaryGradient', 'CheckGradient',
'ScaleGradient']
class GradientProcessor(object):
__metaclass__ = ABCMeta
@abstractmethod
def process(self, grads):
"""
Process the symbolic gradients and return the processed symbolic gradients.
grads: list of (grad, var)
"""
class SummaryGradient(GradientProcessor):
"""
Summarize the histogram and RMS for each gradient variable
"""
def process(self, grads):
for grad, var in grads:
tf.histogram_summary(var.op.name + '/grad', grad)
tf.scalar_summary(var.op.name + '/gradRMS',
tf.sqrt(tf.reduce_mean(tf.square(grad))))
return grads
class CheckGradient(GradientProcessor):
"""
Check gradients for numeric issues
"""
def process(self, grads):
for grad, var in grads:
assert grad is not None, "Grad is None for variable {}".format(var.name)
# TODO make assert work
tf.Assert(tf.reduce_all(tf.is_finite(var)), [var])
return grads
class ScaleGradient(GradientProcessor):
"""
Scale gradient by a multiplier
"""
def __init__(self, multipliers):
"""
multipliers: list of (regex, float)
"""
self.multipliers = multipliers
def process(self, grads):
# TODO use None for zero to speed up?
ret = []
for grad, var in grads:
varname = var.op.name
for regex, val in self.multipliers:
if re.search(regex, varname):
logger.info("Apply lr multiplier {} for {}".format(val, varname))
ret.append((grad * val, var))
break
else:
ret.append((grad, var))
return ret
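
A rough end-to-end sketch of how these processors chain together (not part of this commit): the toy variable, cost, and the GlobalNormClip example processor below are hypothetical, and the chaining loop mirrors Trainer.process_grads above.

    import tensorflow as tf
    from tensorpack.utils.gradproc import *

    class GlobalNormClip(GradientProcessor):
        """Hypothetical custom processor: clip all gradients by global norm."""
        def process(self, grads):
            gs, vs = zip(*grads)
            gs, _ = tf.clip_by_global_norm(list(gs), 5.0)
            return list(zip(gs, vs))

    # toy graph: a single variable 'fc0/W' whose name matches the regex below
    with tf.variable_scope('fc0'):
        W = tf.get_variable('W', shape=[10, 10])
    cost = tf.reduce_sum(tf.square(W))
    opt = tf.train.GradientDescentOptimizer(0.01)

    grads = opt.compute_gradients(cost)
    for proc in [ScaleGradient([('fc0/.*', 0.1)]), SummaryGradient(),
                 CheckGradient(), GlobalNormClip()]:
        grads = proc.process(grads)      # each processor returns a new (grad, var) list
    train_op = opt.apply_gradients(grads)
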
......@@ -39,6 +39,9 @@ class SaverRestore(SessionInit):
self.path = model_path
class ParamRestore(SessionInit):
"""
Restore trainable variables from a dictionary
"""
def __init__(self, param_dict):
self.prms = param_dict
......
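
For context, a rough sketch of how ParamRestore might be used (the parameter names, shapes, and the import path are assumptions, not shown in this diff):

    import numpy as np
    from tensorpack.utils.sessinit import ParamRestore   # import path assumed

    # hypothetical dict mapping variable names to numpy arrays
    params = {'conv0/W': np.zeros((3, 3, 3, 32), dtype='float32'),
              'conv0/b': np.zeros((32,), dtype='float32')}
    init = ParamRestore(params)
    # 'init' would then be passed as the session initializer (e.g. through the
    # training config) so that matching trainable variables are restored.
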