Commit 81bb9ac2 authored by Yuxin Wu

use a better bn variable name

parent 12d27154
@@ -48,17 +48,20 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
         batch_mean, batch_var = tf.nn.moments(x, [0], keep_dims=False)
     else:
         batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], keep_dims=False)
+    # just to make a clear name.
+    batch_mean = tf.identity(batch_mean, 'mean')
+    batch_var = tf.identity(batch_var, 'variance')
     emaname = 'EMA'
-    in_train_tower = not batch_mean.name.startswith('towerp')
-    if in_train_tower:
+    in_main_tower = not batch_mean.name.startswith('towerp')
+    if in_main_tower:
         ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
         ema_apply_op = ema.apply([batch_mean, batch_var])
         ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
     else:
         # use training-statistics in prediction
         assert not use_local_stat
-        # have to do this again to get actual name. see issue:
+        # XXX have to do this again to get actual name. see issue:
         # https://github.com/tensorflow/tensorflow/issues/2740
         ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
         ema_apply_op = ema.apply([batch_mean, batch_var])
...
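The tf.identity renames matter because the EMA shadow variables inherit
their names from the tensors passed to apply(); without the rename they
pick up the autogenerated 'moments/...' prefix from tf.nn.moments. A
minimal sketch of the effect (not part of the commit), assuming TF 1.x
graph-mode APIs via tf.compat.v1; the placeholder shape is made up:

    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()

    x = tf.placeholder(tf.float32, [None, 64], name='input')
    batch_mean, batch_var = tf.nn.moments(x, [0], keep_dims=False)

    # Same rename as in the diff: gives the EMA variables clean names.
    batch_mean = tf.identity(batch_mean, 'mean')
    batch_var = tf.identity(batch_var, 'variance')

    ema = tf.train.ExponentialMovingAverage(decay=0.9, name='EMA')
    ema_apply_op = ema.apply([batch_mean, batch_var])
    print(ema.average(batch_mean).name)  # mean/EMA:0
    print(ema.average(batch_var).name)   # variance/EMA:0

This is also why the prediction branch re-creates the
ExponentialMovingAverage object: per tensorflow/tensorflow#2740, the
shadow-variable name can only be recovered by applying an EMA with the
same name again.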
@@ -132,7 +132,8 @@ class ParamRestore(SessionInit):
     def _init(self, sess):
         sess.run(tf.initialize_all_variables())
-        variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
+        # allow restore non-trainable variables
+        variables = tf.get_collection(tf.GraphKeys.VARIABLES)
         var_dict = dict([v.name, v] for v in variables)
         for name, value in six.iteritems(self.prms):
             if not name.endswith(':0'):
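Context for the collection switch: EMA shadow variables such as
'mean/EMA' above are created with trainable=False, so a
TRAINABLE_VARIABLES lookup can never find them when restoring. A small
sketch of the difference (again via tf.compat.v1; note that
GraphKeys.VARIABLES was later renamed GraphKeys.GLOBAL_VARIABLES):

    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()

    w = tf.get_variable('w', shape=[3])    # trainable
    ema = tf.train.ExponentialMovingAverage(decay=0.9, name='EMA')
    ema_op = ema.apply([w])                # adds 'w/EMA', trainable=False

    print([v.name for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)])
    # ['w:0']
    print([v.name for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)])
    # ['w:0', 'w/EMA:0']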
@@ -145,7 +146,8 @@ class ParamRestore(SessionInit):
             logger.info("Restoring param {}".format(name))
             varshape = tuple(var.get_shape().as_list())
             if varshape != value.shape:
-                assert np.prod(varshape) == np.prod(value.shape)
+                assert np.prod(varshape) == np.prod(value.shape), \
+                    "{}: {}!={}".format(name, varshape, value.shape)
                 logger.warn("Param {} is reshaped during loading!".format(name))
                 value = value.reshape(varshape)
             sess.run(var.assign(value))
...
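The widened assert now reports which parameter failed and both shapes
instead of a bare AssertionError. The reshape branch only runs when the
element counts already agree; a small NumPy illustration with a
hypothetical parameter name:

    import numpy as np

    name = 'fc0/W'                  # hypothetical
    varshape = (4, 4, 64)           # shape of the variable in the graph
    value = np.zeros(4 * 4 * 64)    # same element count, flat in the checkpoint
    if varshape != value.shape:
        assert np.prod(varshape) == np.prod(value.shape), \
            "{}: {}!={}".format(name, varshape, value.shape)
        value = value.reshape(varshape)  # restored with a warning in ParamRestore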