Commit 6dc04278 authored by Yuxin Wu

fix BN again and mute some compatibility noise

parent 6e3e0115
......@@ -18,13 +18,18 @@ class ModelSaver(Callback):
Save the model to logger directory.
"""
def __init__(self, keep_recent=10, keep_freq=0.5,
var_collections=tf.GraphKeys().VARIABLES):
var_collections=None):
"""
:param keep_recent: see `tf.train.Saver` documentation.
:param keep_freq: see `tf.train.Saver` documentation.
"""
self.keep_recent = keep_recent
self.keep_freq = keep_freq
if var_collections is None:
try:
var_collections = tf.GraphKeys.GLOBAL_VARIABLES
except:
var_collections = tf.GraphKeys.VARIABLES
if not isinstance(var_collections, list):
var_collections = [var_collections]
self.var_collections = var_collections
......
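Deferring the collection-key lookup to __init__ is what keeps ModelSaver working across TF releases that renamed the key. Written without the bare try/except, the same fallback could look like the following sketch (not part of the commit; getattr stands in for the try/except above):

import tensorflow as tf

# Prefer the newer GraphKeys.GLOBAL_VARIABLES; fall back to the pre-renaming VARIABLES key.
key = getattr(tf.GraphKeys, 'GLOBAL_VARIABLES', tf.GraphKeys.VARIABLES)
var_collections = key if isinstance(key, list) else [key]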
......@@ -44,7 +44,12 @@ class GraphVarParam(HyperParam):
self._readable_name, self.var_name = get_op_var_name(name)
def setup_graph(self):
try:
all_vars = tf.global_variables()
except:
# TODO
all_vars = tf.all_variables()
for v in all_vars:
if v.name == self.var_name:
self.var = v
......
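For reference, the lookup this hunk performs, isolated from the class (the variable name here is hypothetical; a tf.Variable's .name always carries the ':0' output suffix, which is presumably why get_op_var_name returns both forms):

import tensorflow as tf

var_name = 'learning_rate:0'  # hypothetical; a variable's .name ends with ':0'
all_vars = tf.global_variables() if hasattr(tf, 'global_variables') else tf.all_variables()
target = next((v for v in all_vars if v.name == var_name), None)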
......@@ -60,7 +60,9 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
if use_local_stat:
# training tower
if ctx.is_training:
reuse = tf.get_variable_scope().reuse
#reuse = tf.get_variable_scope().reuse
with tf.variable_scope(tf.get_variable_scope(), reuse=False):
# BatchNorm in reuse scope can be tricky! Moving mean/variance are not reused
with tf.name_scope(None): # https://github.com/tensorflow/tensorflow/issues/2740
# TODO if reuse=True, try to find and use the existing statistics
# how to use multiple tensors to update one EMA? seems impossible
......@@ -72,7 +74,8 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
add_model_variable(ema_mean)
add_model_variable(ema_var)
else:
# no apply() is called here, no magic vars will get created
# no apply() is called here, no magic vars will get created,
# no reuse issue will happen
assert not ctx.is_training
with tf.name_scope(None):
ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
......@@ -81,14 +84,16 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
sc = tf.get_variable_scope()
if ctx.is_main_tower:
# main tower, but needs to use global stat. global stat must be from outside
# TODO when reuse=True, the variable name could actually be different
# TODO when reuse=True, the desired variable name could
# actually be different, because a different var is created
# for different reuse tower
ema_mean = tf.get_variable('mean/' + emaname, [n_out])
ema_var = tf.get_variable('variance/' + emaname, [n_out])
else:
## use statistics in another tower
G = tf.get_default_graph()
ema_mean = ctx.find_tensor_in_main_tower(G, mean_var_name)
ema_var = ctx.find_tensor_in_main_tower(G, var_var_name)
ema_mean = ctx.find_tensor_in_main_tower(G, mean_var_name + ':0')
ema_var = ctx.find_tensor_in_main_tower(G, var_var_name + ':0')
if use_local_stat:
batch = tf.cast(tf.shape(x)[0], tf.float32)
......
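One detail worth spelling out: the ':0' suffix is appended because looking up a tensor by name requires the full 'op_name:output_index' form, while op names carry no suffix. A minimal standalone sketch of that naming rule (hypothetical op name, assuming TF 1.x graph mode; find_tensor_in_main_tower is presumably a thin wrapper around this kind of lookup):

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    tf.zeros([4], name='mean/EMA')  # hypothetical op standing in for the moving mean

t = g.get_tensor_by_name('mean/EMA:0')    # tensor lookup needs the output index
op = g.get_operation_by_name('mean/EMA')  # op lookup uses the bare name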
......@@ -111,7 +111,11 @@ class Trainer(object):
logger.info("Initializing graph variables ...")
# TODO newsession + sessinit?
self.sess.run(tf.initialize_all_variables())
try:
initop = tf.global_variables_initializer()
except:
initop = tf.initialize_all_variables()
self.sess.run(initop)
self.config.session_init.init(self.sess)
tf.get_default_graph().finalize()
......
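The same compatibility shim, lifted into a small helper, could read as follows (a sketch only, assuming a TF 1.x-era API; getattr replaces the bare try/except above):

import tensorflow as tf

def global_init_op():
    # Prefer the newer global_variables_initializer(); fall back to the
    # deprecated initialize_all_variables() on old releases.
    fn = getattr(tf, 'global_variables_initializer', None) or tf.initialize_all_variables
    return fn()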