Commit f60989d3 authored by Yuxin Wu's avatar Yuxin Wu

fix bn and rewrite saverrestore with var.load

parent 3f238a01
......@@ -118,7 +118,7 @@ def get_bn_variables(x, use_scale, use_bias):
initializer=tf.constant_initializer(), trainable=False)
moving_var = tf.get_variable('variance/EMA', [n_out],
initializer=tf.constant_initializer(), trainable=False)
return beta, gamma, moving_mean, moving_var
return x, beta, gamma, moving_mean, moving_var
def update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay):
......@@ -171,7 +171,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
with the official inceptionv3 example).
"""
shape = x.get_shape().as_list()
beta, gamma, moving_mean, moving_var = get_bn_variables(x, use_scale, use_bias)
x, beta, gamma, moving_mean, moving_var = get_bn_variables(x, use_scale, use_bias)
ctx = get_current_tower_context()
if use_local_stat is None:
......@@ -231,7 +231,7 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
"""
shape = x.get_shape().as_list()
beta, gamma, moving_mean, moving_var = get_bn_variables(x, use_scale, use_bias)
x, beta, gamma, moving_mean, moving_var = get_bn_variables(x, use_scale, use_bias)
ctx = get_current_tower_context()
use_local_stat = ctx.is_training
......
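
The batch-norm change above threads the input through get_bn_variables: the helper now returns x as well, presumably because it may transform the input (e.g. reshape 2D fully-connected activations to 4D) before creating the per-channel variables, so BatchNorm and BatchRenorm must keep using the returned tensor. A minimal sketch of such a helper under that assumption (the reshape and the beta/gamma handling are illustrative; only the EMA variable definitions are taken from the hunk above):

import tensorflow as tf

def get_bn_variables_sketch(x, use_scale, use_bias):
    # Sketch only: the 2D->4D reshape is an assumption about why x is returned.
    shape = x.get_shape().as_list()
    assert len(shape) in [2, 4]
    n_out = shape[-1]                          # number of channels
    if len(shape) == 2:
        x = tf.reshape(x, [-1, 1, 1, n_out])   # give FC activations 1x1 spatial dims
    if use_bias:
        beta = tf.get_variable('beta', [n_out], initializer=tf.constant_initializer())
    else:
        beta = tf.zeros([n_out], name='beta')
    if use_scale:
        gamma = tf.get_variable('gamma', [n_out], initializer=tf.constant_initializer(1.0))
    else:
        gamma = tf.ones([n_out], name='gamma')
    moving_mean = tf.get_variable('mean/EMA', [n_out],
                                  initializer=tf.constant_initializer(), trainable=False)
    moving_var = tf.get_variable('variance/EMA', [n_out],
                                 initializer=tf.constant_initializer(), trainable=False)
    # Return the possibly-reshaped input so callers normalize the right tensor.
    return x, beta, gamma, moving_mean, moving_var
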
......@@ -60,8 +60,8 @@ def get_global_step_var():
with tf.variable_scope(scope, reuse=False), \
tf.name_scope(None):
var = tf.get_variable(GLOBAL_STEP_OP_NAME,
initializer=0,
trainable=False, dtype=tf.int32)
initializer=tf.constant(0, dtype=tf.int64),
trainable=False, dtype=tf.int64)
return var
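
The global step is widened from int32 to int64 here, and the initializer is created with the matching dtype so tf.get_variable does not see a conflicting initializer/dtype pair. A small standalone sketch of the same pattern (TF1 API; the variable-scope handling from the hunk is omitted, and the variable name is made up):

import tensorflow as tf

# Hypothetical minimal example of an int64 global step.
global_step = tf.get_variable('my_global_step',
                              initializer=tf.constant(0, dtype=tf.int64),
                              trainable=False, dtype=tf.int64)
increment = tf.assign_add(global_step, 1)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(increment))   # -> 1, stored as a 64-bit integer
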
......
......@@ -6,6 +6,7 @@
import tensorflow as tf
from contextlib import contextmanager
from .gradproc import apply_grad_processors as apply_gradproc
from .gradproc import FilterNoneGrad
__all__ = ['apply_grad_processors', 'ProxyOptimizer',
'PostProcessOptimizer', 'VariableAssignmentOptimizer']
......@@ -115,3 +116,59 @@ class VariableAssignmentOptimizer(PostProcessOptimizer):
return t
return tf.assign(v, t, use_locking=False).op
super(VariableAssignmentOptimizer, self).__init__(opt, f)
class AccumGradOptimizer(ProxyOptimizer):
def __init__(self, opt, niter):
super(AccumGradOptimizer, self).__init__(opt)
self._niter = niter
self._name = "AccumGrad"
self._counter = None
def _create_accum_slots(self, var_list):
slots = []
for v in var_list:
s = self._zeros_slot(v, "accum", self._name)
slots.append(s)
return slots
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
assert global_step is None, \
"AccumGradOptimizer doesn't support the option global_step! " \
"Please maintain it yourself."
grads_and_vars = FilterNoneGrad().process(grads_and_vars)
vs = []
for g, v in grads_and_vars:
assert isinstance(g, tf.Tensor) and isinstance(v, tf.Variable), \
"AccumGradOptimizer only works for dense update! " \
"Types of v and g are {} and {}".format(type(v), type(g))
vs.append(v)
with tf.control_dependencies(None):
slots = self._create_accum_slots(vs)
slots_and_vars = [(s, gv[1]) for s, gv in zip(slots, grads_and_vars)]
# Create the counter on the same device as the first variable.
with tf.variable_scope(self._name), \
tf.colocate_with(vs[0]):
counter = tf.Variable(
0, name="counter", trainable=False, dtype=tf.int32)
ops = []
for s, gv in zip(slots, grads_and_vars):
g, v = gv
ops.append(s.assign_add(g))   # accumulate this step's gradient into the slot
update_counter = tf.assign_add(counter, 1, name='update_counter')
update_slot_op = tf.group(update_counter, *ops, name='update_slot')
def update_grad():
update_op = self._opt.apply_gradients(slots_and_vars)
with tf.control_dependencies([update_op]):
clear_ops = [tf.assign(s, tf.zeros_like(s)) for s in slots]   # reset the accumulators after applying
return tf.group(*clear_ops, name='update_grad')
pred = tf.equal(tf.mod(counter, self._niter), 0)
with tf.control_dependencies([update_slot_op]):
if name is None:
name = 'cond_update_grad'
op = tf.cond(pred, update_grad, lambda: tf.no_op(), name=name)
return op
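
The new AccumGradOptimizer accumulates dense gradients into per-variable slots and lets the wrapped optimizer apply them only once every niter calls, which emulates a larger effective batch size. A hedged usage sketch (the import path is assumed; minimize() comes from the tf.train.Optimizer base class and ends up in the apply_gradients above):

import tensorflow as tf
# from tensorpack.tfutils.optimizer import AccumGradOptimizer   # assumed import path

x = tf.get_variable('x', shape=[], initializer=tf.constant_initializer(0.0))
loss = tf.square(x - 3.0)

opt = AccumGradOptimizer(tf.train.GradientDescentOptimizer(0.1), niter=4)
train_op = opt.minimize(loss)        # no global_step: the wrapper asserts it is None

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(8):
        sess.run(train_op)           # the variable only moves on (roughly) every 4th call
    print(sess.run(x))
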
......@@ -4,7 +4,6 @@
import os
from abc import abstractmethod, ABCMeta
from collections import defaultdict
import numpy as np
import tensorflow as tf
import six
......@@ -57,7 +56,7 @@ class NewSession(SessionInit):
class SaverRestore(SessionInit):
"""
Restore an old model saved by :class:`ModelSaver`.
Restore a TensorFlow checkpoint saved by :class:`tf.train.Saver` or :class:`ModelSaver`.
"""
def __init__(self, model_path, prefix=None):
......@@ -73,28 +72,26 @@ class SaverRestore(SessionInit):
def _init(self, sess):
logger.info(
"Restoring checkpoint from {} ...".format(self.path))
chkpt_vars = SaverRestore._read_checkpoint_vars(self.path)
vars_map = self._get_vars_to_restore_multimap(chkpt_vars)
for dic in SaverRestore._produce_restore_dict(vars_map):
# multiple savers under the same name scope would cause an error:
# training/saver.py: assert restore_op.name.endswith("restore_all"), restore_op.name
saver = tf.train.Saver(var_list=dic, name=str(id(dic)), write_version=2)
saver.restore(sess, self.path)
reader, chkpt_vars = SaverRestore._read_checkpoint_vars(self.path)
graph_vars = tf.global_variables()
chkpt_vars_used = set()
@staticmethod
def _produce_restore_dict(vars_multimap):
"""
Produce {var_name: var} dict that can be used by `tf.train.Saver`, from a {var_name: [vars]} dict.
"""
while len(vars_multimap):
ret = {}
for k in list(vars_multimap.keys()):
v = vars_multimap[k]
ret[k] = v[-1]
del v[-1]
if not len(v):
del vars_multimap[k]
yield ret
with sess.as_default():
for v in graph_vars:
name = get_savename_from_varname(v.name, varname_prefix=self.prefix)
if name in chkpt_vars:
val = reader.get_tensor(name)
SessionUpdate.load_value_to_var(v, val)
chkpt_vars_used.add(name)
else:
vname = v.op.name
if not is_training_name(vname):
logger.warn("Variable {} in the graph not found in checkpoint!".format(vname))
if len(chkpt_vars_used) < len(chkpt_vars):
unused = chkpt_vars - chkpt_vars_used
for name in sorted(unused):
if not is_training_name(name):
logger.warn("Variable {} in checkpoint not found in the graph!".format(name))
@staticmethod
def _read_checkpoint_vars(model_path):
......@@ -105,37 +102,7 @@ class SaverRestore(SessionInit):
if v.startswith(PREDICT_TOWER):
logger.error("Found {} in checkpoint. "
"But anything from prediction tower shouldn't be saved.".format(v.name))
return set(ckpt_vars)
def _get_vars_to_restore_multimap(self, vars_available):
"""
:param vars_available: variable names available in the checkpoint, for existence checking
:returns: a dict of {var_name: [var, var]} to restore
"""
vars_to_restore = tf.global_variables()
var_dict = defaultdict(list)
chkpt_vars_used = set()
for v in vars_to_restore:
name = get_savename_from_varname(v.name, varname_prefix=self.prefix)
# try to load both 'varname' and 'opname' from checkpoint
# because some old checkpoint might not have ':0'
if name in vars_available:
var_dict[name].append(v)
chkpt_vars_used.add(name)
elif name.endswith(':0'):
name = name[:-2]
if name in vars_available:
var_dict[name].append(v)
chkpt_vars_used.add(name)
else:
if not is_training_name(v.op.name):
logger.warn("Variable {} in the graph not found in checkpoint!".format(v.op.name))
if len(chkpt_vars_used) < len(vars_available):
unused = vars_available - chkpt_vars_used
for name in sorted(unused):
if not is_training_name(name):
logger.warn("Variable {} in checkpoint not found in the graph!".format(name))
return var_dict
return reader, set(ckpt_vars)
class ParamRestore(SessionInit):
......
......@@ -59,23 +59,64 @@ class SessionUpdate(object):
savename = get_savename_from_varname(v.name)
self.name_map[savename].append(v)
@staticmethod
def load_value_to_var(var, val, strict=False):
"""
Call `var.load(val)` with the default session.
Args:
var (tf.Variable): the variable to assign to.
val: the value to load, e.g. a numpy array.
strict (bool): if False, tolerate benign mismatches by reshaping the value and up-casting its dtype.
"""
if strict:
var.load(val)
return
name = var.op.name
# check incompatible shape
varshape = tuple(var.get_shape().as_list())
if varshape != val.shape:
# TODO only allow reshape when shape different by empty axis
assert np.prod(varshape) == np.prod(val.shape), \
"{}: {}!={}".format(name, varshape, val.shape)
logger.warn("Variable {} is reshaped during assigning".format(name))
val = val.reshape(varshape)
# fix some common type incompatibility problem, but is certainly not enough
def upcast(vartype, valtype):
# allow up-casting
if vartype == tf.float64 and valtype == np.float32:
return np.float64
if vartype in [tf.int64, tf.int32] and valtype in [np.int32, np.int16, np.int8]:
return np.int64 if vartype == tf.int64 else np.int32
return None
if hasattr(val, 'dtype'):
vartype = var.value().dtype
if vartype != val.dtype:
msg = "Variable {} has dtype {} but was given a value of dtype {}.".format(name, vartype, val.dtype)
newtype = upcast(var.dtype.base_dtype, val.dtype)   # compare against the non-ref dtype
if newtype is not None:
val = newtype(val)
logger.warn(msg + " Load it after casting!")
else:
assert vartype == val.dtype, msg
try:
var.load(val)
except tf.errors.InvalidArgumentError:
logger.exc("Cannot load this value to the variable {}".format(name))
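
In the non-strict mode above, the helper tolerates benign mismatches: a value with the right number of elements is reshaped to the variable's shape, and a narrower dtype may be up-cast (e.g. float32 into a float64 variable) before var.load(). A small illustration of the reshape path (variable name and import path are assumptions):

import numpy as np
import tensorflow as tf
# from tensorpack.tfutils.varmanip import SessionUpdate   # assumed import path

w = tf.get_variable('w', shape=[2, 3], dtype=tf.float32)
val = np.ones([6], dtype=np.float32)         # same element count, different shape

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Reshaped to (2, 3) and then loaded with var.load(); a warning is logged.
    SessionUpdate.load_value_to_var(w, val)
    print(sess.run(w))
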
def update(self, prms):
"""
Args:
prms(dict): dict of {variable name: value}
Any name in prms must be in the graph and in vars_to_update.
"""
for name, value in six.iteritems(prms):
assert name in self.name_map
for v in self.name_map[name]:
varshape = tuple(v.get_shape().as_list())
if varshape != value.shape:
# TODO only allow reshape when shape different by empty axis
assert np.prod(varshape) == np.prod(value.shape), \
"{}: {}!={}".format(name, varshape, value.shape)
logger.warn("Param {} is reshaped during assigning".format(name))
value = value.reshape(varshape)
v.load(value, session=self.sess)
with self.sess.as_default():
for name, value in six.iteritems(prms):
assert name in self.name_map
for v in self.name_map[name]:
SessionUpdate.load_value_to_var(v, value)
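
With this change, update() simply enters the session and defers every assignment to load_value_to_var, so it inherits the same reshape and up-cast tolerance. Hypothetical usage, assuming the constructor takes a session and the variables to manage (names and shapes below are made up):

import numpy as np
import tensorflow as tf

updater = SessionUpdate(sess, tf.global_variables())       # sess: an existing tf.Session
updater.update({'fc0/W:0': np.zeros((1024, 10), dtype=np.float32)})
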
def dump_session_params(path):
......