Commit aa124d20 authored by Yuxin Wu

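Bug fix for model saving: dump variables under their save names, and register BatchNorm EMA variables for saving only from the main training tower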

parent 3448ffd8
...@@ -43,7 +43,11 @@ with tf.Graph().as_default() as G: ...@@ -43,7 +43,11 @@ with tf.Graph().as_default() as G:
else: else:
var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
var.extend(tf.get_collection(EXTRA_SAVE_VARS_KEY)) var.extend(tf.get_collection(EXTRA_SAVE_VARS_KEY))
var_dict = {}
for v in var:
name = varmanip.get_savename_from_varname(v.name)
var_dict[name] = v
logger.info("Variables to dump:") logger.info("Variables to dump:")
logger.info(", ".join([v.name for v in var])) logger.info(", ".join(var_dict.keys()))
saver = tf.train.Saver(var_list=var) saver = tf.train.Saver(var_list=var_dict)
saver.save(sess, args.output, write_meta_graph=False) saver.save(sess, args.output, write_meta_graph=False)
...@@ -55,12 +55,16 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5): ...@@ -55,12 +55,16 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
# XXX a hack to handle training tower & prediction tower together.... # XXX a hack to handle training tower & prediction tower together....
emaname = 'EMA' emaname = 'EMA'
if not batch_mean.name.startswith('towerp'): if not batch_mean.name.startswith('towerp'):
# training tower
with tf.name_scope(None): # https://github.com/tensorflow/tensorflow/issues/2740 with tf.name_scope(None): # https://github.com/tensorflow/tensorflow/issues/2740
ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname) ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
ema_apply_op = ema.apply([batch_mean, batch_var]) ema_apply_op = ema.apply([batch_mean, batch_var])
ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var) ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
tf.add_to_collection(EXTRA_SAVE_VARS_KEY, ema_mean) if not batch_mean.name.startswith('tower') or \
tf.add_to_collection(EXTRA_SAVE_VARS_KEY, ema_var) batch_mean.name.startswith('tower0'):
# inside main training tower
tf.add_to_collection(EXTRA_SAVE_VARS_KEY, ema_mean)
tf.add_to_collection(EXTRA_SAVE_VARS_KEY, ema_var)
else: else:
# use training-statistics in prediction # use training-statistics in prediction
assert not use_local_stat assert not use_local_stat
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment