Commit f60989d3 authored by Yuxin Wu

fix bn and rewrite saverrestore with var.load

parent 3f238a01
@@ -118,7 +118,7 @@ def get_bn_variables(x, use_scale, use_bias):
                                   initializer=tf.constant_initializer(), trainable=False)
     moving_var = tf.get_variable('variance/EMA', [n_out],
                                  initializer=tf.constant_initializer(), trainable=False)
-    return beta, gamma, moving_mean, moving_var
+    return x, beta, gamma, moving_mean, moving_var
 
 
 def update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay):
@@ -171,7 +171,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
         with the official inceptionv3 example).
     """
     shape = x.get_shape().as_list()
-    beta, gamma, moving_mean, moving_var = get_bn_variables(x, use_scale, use_bias)
+    x, beta, gamma, moving_mean, moving_var = get_bn_variables(x, use_scale, use_bias)
 
     ctx = get_current_tower_context()
     if use_local_stat is None:
@@ -231,7 +231,7 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
     """
     shape = x.get_shape().as_list()
-    beta, gamma, moving_mean, moving_var = get_bn_variables(x, use_scale, use_bias)
+    x, beta, gamma, moving_mean, moving_var = get_bn_variables(x, use_scale, use_bias)
 
     ctx = get_current_tower_context()
     use_local_stat = ctx.is_training
...
@@ -60,8 +60,8 @@ def get_global_step_var():
     with tf.variable_scope(scope, reuse=False), \
             tf.name_scope(None):
         var = tf.get_variable(GLOBAL_STEP_OP_NAME,
-                              initializer=0,
-                              trainable=False, dtype=tf.int32)
+                              initializer=tf.constant(0, dtype=tf.int64),
+                              trainable=False, dtype=tf.int64)
     return var
...
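
For illustration (not part of this commit): after this change the global step variable is int64 instead of int32. A minimal sketch, assuming get_global_step_var is importable from tensorpack.tfutils.common:

    import tensorflow as tf
    from tensorpack.tfutils.common import get_global_step_var  # import path assumed

    gs = get_global_step_var()
    # the variable's base dtype is now 64-bit
    assert gs.dtype.base_dtype == tf.int64
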
@@ -6,6 +6,7 @@
 import tensorflow as tf
 from contextlib import contextmanager
 from .gradproc import apply_grad_processors as apply_gradproc
+from .gradproc import FilterNoneGrad
 
 __all__ = ['apply_grad_processors', 'ProxyOptimizer',
            'PostProcessOptimizer', 'VariableAssignmentOptimizer']
@@ -115,3 +116,59 @@ class VariableAssignmentOptimizer(PostProcessOptimizer):
                 return t
             return tf.assign(v, t, use_locking=False).op
         super(VariableAssignmentOptimizer, self).__init__(opt, f)
+
+
+class AccumGradOptimizer(ProxyOptimizer):
+    def __init__(self, opt, niter):
+        super(AccumGradOptimizer, self).__init__(opt)
+        self._niter = niter
+        self._name = "AccumGrad"
+        self._counter = None
+
+    def _create_accum_slots(self, var_list):
+        slots = []
+        for v in var_list:
+            # one zero-initialized accumulator slot per variable
+            s = self._zeros_slot(v, "accum", self._name)
+            slots.append(s)
+        return slots
+
+    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
+        assert global_step is None, \
+            "AccumGradOptimizer doesn't support the option global_step! " \
+            "Please maintain it yourself."
+        grads_and_vars = FilterNoneGrad().process(grads_and_vars)
+        vs = []
+        for g, v in grads_and_vars:
+            assert isinstance(g, tf.Tensor) and isinstance(v, tf.Variable), \
+                "AccumGradOptimizer only works for dense update! " \
+                "Types of v and g are {} and {}".format(type(v), type(g))
+            vs.append(v)
+
+        with tf.control_dependencies(None):
+            slots = self._create_accum_slots(vs)
+            slots_and_vars = [(s, gv[1]) for s, gv in zip(slots, grads_and_vars)]
+
+            # Create the counter on the same device as the first variable.
+            with tf.variable_scope(self._name), \
+                    tf.colocate_with(vs[0]):
+                counter = tf.Variable(
+                    0, name="counter", trainable=False, dtype=tf.int32)
+
+        ops = []
+        for s, gv in zip(slots, grads_and_vars):
+            g, v = gv
+            ops.append(s.assign_add(g))
+        update_counter = tf.assign_add(counter, 1, name='update_counter')
+        update_slot_op = tf.group(update_counter, *ops, name='update_slot')
+
+        def update_grad():
+            update_op = self._opt.apply_gradients(slots_and_vars)
+            with tf.control_dependencies([update_op]):
+                clear_ops = [tf.assign(s, tf.zeros_like(s)) for s in slots]
+            return tf.group(*clear_ops, name='update_grad')
+
+        pred = tf.equal(tf.mod(counter, self._niter), 0)
+        with tf.control_dependencies([update_slot_op]):
+            if name is None:
+                name = 'cond_update_grad'
+            op = tf.cond(pred, update_grad, lambda: tf.no_op(), name=name)
+        return op
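
For illustration (not part of this commit), a minimal sketch of how the new AccumGradOptimizer could be driven; the import path, the toy cost, and niter=4 are assumptions:

    import tensorflow as tf
    from tensorpack.tfutils.optimizer import AccumGradOptimizer  # import path assumed

    x = tf.get_variable('x', shape=[10], initializer=tf.constant_initializer(0.0))
    cost = tf.reduce_sum(tf.square(x - 1.0))

    opt = AccumGradOptimizer(tf.train.GradientDescentOptimizer(0.1), niter=4)
    # every run adds the gradients into per-variable "accum" slots; the wrapped
    # optimizer is applied (and the slots cleared) once every 4 runs
    train_op = opt.apply_gradients(opt.compute_gradients(cost))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(8):
            sess.run(train_op)
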
@@ -4,7 +4,6 @@
 import os
 from abc import abstractmethod, ABCMeta
-from collections import defaultdict
 import numpy as np
 import tensorflow as tf
 import six
@@ -57,7 +56,7 @@ class NewSession(SessionInit):
 class SaverRestore(SessionInit):
     """
-    Restore an old model saved by :class:`ModelSaver`.
+    Restore a TensorFlow checkpoint saved by :class:`tf.train.Saver` or :class:`ModelSaver`.
     """
     def __init__(self, model_path, prefix=None):
@@ -73,28 +72,26 @@ class SaverRestore(SessionInit):
     def _init(self, sess):
         logger.info(
             "Restoring checkpoint from {} ...".format(self.path))
-        chkpt_vars = SaverRestore._read_checkpoint_vars(self.path)
-        vars_map = self._get_vars_to_restore_multimap(chkpt_vars)
-        for dic in SaverRestore._produce_restore_dict(vars_map):
-            # multiple saver under same name scope would cause error:
-            # training/saver.py: assert restore_op.name.endswith("restore_all"), restore_op.name
-            saver = tf.train.Saver(var_list=dic, name=str(id(dic)), write_version=2)
-            saver.restore(sess, self.path)
-
-    @staticmethod
-    def _produce_restore_dict(vars_multimap):
-        """
-        Produce {var_name: var} dict that can be used by `tf.train.Saver`, from a {var_name: [vars]} dict.
-        """
-        while len(vars_multimap):
-            ret = {}
-            for k in list(vars_multimap.keys()):
-                v = vars_multimap[k]
-                ret[k] = v[-1]
-                del v[-1]
-                if not len(v):
-                    del vars_multimap[k]
-            yield ret
+        reader, chkpt_vars = SaverRestore._read_checkpoint_vars(self.path)
+        graph_vars = tf.global_variables()
+        chkpt_vars_used = set()
+        with sess.as_default():
+            for v in graph_vars:
+                name = get_savename_from_varname(v.name, varname_prefix=self.prefix)
+                if name in chkpt_vars:
+                    val = reader.get_tensor(name)
+                    SessionUpdate.load_value_to_var(v, val)
+                    chkpt_vars_used.add(name)
+                else:
+                    vname = v.op.name
+                    if not is_training_name(vname):
+                        logger.warn("Variable {} in the graph not found in checkpoint!".format(vname))
+        if len(chkpt_vars_used) < len(chkpt_vars):
+            unused = chkpt_vars - chkpt_vars_used
+            for name in sorted(unused):
+                if not is_training_name(name):
+                    logger.warn("Variable {} in checkpoint not found in the graph!".format(name))
 
     @staticmethod
     def _read_checkpoint_vars(model_path):
@@ -105,37 +102,7 @@ class SaverRestore(SessionInit):
             if v.startswith(PREDICT_TOWER):
                 logger.error("Found {} in checkpoint. "
                              "But anything from prediction tower shouldn't be saved.".format(v.name))
-        return set(ckpt_vars)
-
-    def _get_vars_to_restore_multimap(self, vars_available):
-        """
-        :param vars_available: varaible names available in the checkpoint, for existence checking
-        :returns: a dict of {var_name: [var, var]} to restore
-        """
-        vars_to_restore = tf.global_variables()
-        var_dict = defaultdict(list)
-        chkpt_vars_used = set()
-        for v in vars_to_restore:
-            name = get_savename_from_varname(v.name, varname_prefix=self.prefix)
-            # try to load both 'varname' and 'opname' from checkpoint
-            # because some old checkpoint might not have ':0'
-            if name in vars_available:
-                var_dict[name].append(v)
-                chkpt_vars_used.add(name)
-            elif name.endswith(':0'):
-                name = name[:-2]
-                if name in vars_available:
-                    var_dict[name].append(v)
-                    chkpt_vars_used.add(name)
-            else:
-                if not is_training_name(v.op.name):
-                    logger.warn("Variable {} in the graph not found in checkpoint!".format(v.op.name))
-        if len(chkpt_vars_used) < len(vars_available):
-            unused = vars_available - chkpt_vars_used
-            for name in sorted(unused):
-                if not is_training_name(name):
-                    logger.warn("Variable {} in checkpoint not found in the graph!".format(name))
-        return var_dict
+        return reader, set(ckpt_vars)
 
 
 class ParamRestore(SessionInit):
...
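
Again for illustration (not part of the diff), a minimal sketch of how the rewritten SaverRestore is typically used; the checkpoint path is a placeholder, and the init(sess) entry point is assumed from the SessionInit base class:

    import tensorflow as tf
    from tensorpack.tfutils.sessinit import SaverRestore  # import path assumed

    # build the graph first; SaverRestore then fills in matching variables one by one
    w = tf.get_variable('fc0/W', shape=[3, 3], initializer=tf.constant_initializer(0.0))

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    SaverRestore('/path/to/model-10000').init(sess)   # placeholder checkpoint path
    # variables present on only one side (graph or checkpoint) are reported via logger.warn()
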
@@ -59,23 +59,64 @@ class SessionUpdate(object):
             savename = get_savename_from_varname(v.name)
             self.name_map[savename].append(v)
+
+    @staticmethod
+    def load_value_to_var(var, val, strict=False):
+        """
+        Call `var.load(val)` with the default session.
+
+        Args:
+            var (tf.Variable):
+            strict (bool): if False, the value may be reshaped or up-cast to match the variable.
+        """
+        if strict:
+            var.load(val)
+            return
+        name = var.op.name
+
+        # check incompatible shape
+        varshape = tuple(var.get_shape().as_list())
+        if varshape != val.shape:
+            # TODO only allow reshape when shapes differ by empty axes
+            assert np.prod(varshape) == np.prod(val.shape), \
+                "{}: {}!={}".format(name, varshape, val.shape)
+            logger.warn("Variable {} is reshaped during assigning".format(name))
+            val = val.reshape(varshape)
+
+        # fix some common dtype incompatibilities, but this is certainly not enough
+        def upcast(vartype, valtype):
+            # allow up-casting
+            if vartype == tf.float64 and valtype == np.float32:
+                return np.float64
+            if vartype in [tf.int64, tf.int32] and valtype in [np.int32, np.int16, np.int8]:
+                return np.int64 if vartype == tf.int64 else np.int32
+            return None
+
+        if hasattr(val, 'dtype'):
+            vartype = var.value().dtype
+            if vartype != val.dtype:
+                msg = "Variable {} has dtype {} but was given a value of dtype {}.".format(
+                    name, vartype, val.dtype)
+                newtype = upcast(vartype, val.dtype)
+                if newtype is not None:
+                    val = newtype(val)
+                    logger.warn(msg + " The value will be loaded after casting!")
+                else:
+                    assert vartype == val.dtype, msg
+        try:
+            var.load(val)
+        except tf.errors.InvalidArgumentError:
+            logger.exc("Cannot load this value to the variable {}".format(name))
+
     def update(self, prms):
         """
         Args:
             prms(dict): dict of {variable name: value}
                 Any name in prms must be in the graph and in vars_to_update.
         """
-        for name, value in six.iteritems(prms):
-            assert name in self.name_map
-            for v in self.name_map[name]:
-                varshape = tuple(v.get_shape().as_list())
-                if varshape != value.shape:
-                    # TODO only allow reshape when shape different by empty axis
-                    assert np.prod(varshape) == np.prod(value.shape), \
-                        "{}: {}!={}".format(name, varshape, value.shape)
-                    logger.warn("Param {} is reshaped during assigning".format(name))
-                    value = value.reshape(varshape)
-                v.load(value, session=self.sess)
+        with self.sess.as_default():
+            for name, value in six.iteritems(prms):
+                assert name in self.name_map
+                for v in self.name_map[name]:
+                    SessionUpdate.load_value_to_var(v, value)
 
 
 def dump_session_params(path):
...
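
Also for illustration (not part of the diff), a small sketch of the reshape handling in the new SessionUpdate.load_value_to_var; the variable name and value are made up:

    import numpy as np
    import tensorflow as tf
    from tensorpack.tfutils.varmanip import SessionUpdate  # import path assumed

    v = tf.get_variable('fc0/b', shape=[4, 1], initializer=tf.constant_initializer(0.0))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        with sess.as_default():               # var.load() uses the default session
            # a (4,) value has the same number of elements as the (4, 1) variable,
            # so it is reshaped with a warning instead of raising
            SessionUpdate.load_value_to_var(v, np.arange(4, dtype=np.float32))
        print(sess.run(v))
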