Commit aa124d20 authored by Yuxin Wu

Bug fix for model saving

parent 3448ffd8
......@@ -43,7 +43,11 @@ with tf.Graph().as_default() as G:
else:
var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
var.extend(tf.get_collection(EXTRA_SAVE_VARS_KEY))
var_dict = {}
for v in var:
name = varmanip.get_savename_from_varname(v.name)
var_dict[name] = v
logger.info("Variables to dump:")
logger.info(", ".join([v.name for v in var]))
saver = tf.train.Saver(var_list=var)
logger.info(", ".join(var_dict.keys()))
saver = tf.train.Saver(var_list=var_dict)
saver.save(sess, args.output, write_meta_graph=False)
......@@ -55,12 +55,16 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
# XXX a hack to handle training tower & prediction tower together....
emaname = 'EMA'
if not batch_mean.name.startswith('towerp'):
# training tower
with tf.name_scope(None): # https://github.com/tensorflow/tensorflow/issues/2740
ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
ema_apply_op = ema.apply([batch_mean, batch_var])
ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
tf.add_to_collection(EXTRA_SAVE_VARS_KEY, ema_mean)
tf.add_to_collection(EXTRA_SAVE_VARS_KEY, ema_var)
if not batch_mean.name.startswith('tower') or \
batch_mean.name.startswith('tower0'):
# inside main training tower
tf.add_to_collection(EXTRA_SAVE_VARS_KEY, ema_mean)
tf.add_to_collection(EXTRA_SAVE_VARS_KEY, ema_var)
else:
# use training-statistics in prediction
assert not use_local_stat
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment