Commit 6dc04278 authored by Yuxin Wu

fix BN again and mute some compatibility noise

parent 6e3e0115

@@ -18,13 +18,18 @@ class ModelSaver(Callback):
     Save the model to logger directory.
     """
     def __init__(self, keep_recent=10, keep_freq=0.5,
-                 var_collections=tf.GraphKeys().VARIABLES):
+                 var_collections=None):
         """
         :param keep_recent: see `tf.train.Saver` documentation.
         :param keep_freq: see `tf.train.Saver` documentation.
         """
         self.keep_recent = keep_recent
         self.keep_freq = keep_freq
+        if var_collections is None:
+            try:
+                var_collections = tf.GraphKeys.GLOBAL_VARIABLES
+            except:
+                var_collections = tf.GraphKeys.VARIABLES
         if not isinstance(var_collections, list):
             var_collections = [var_collections]
         self.var_collections = var_collections
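
The "compatibility noise" part of the commit is muted with try/except fallbacks like the one above. A minimal standalone sketch of the same pattern, assuming only that tf.GraphKeys.VARIABLES was renamed to GLOBAL_VARIABLES in newer TensorFlow releases; the helper name is illustrative, not part of the commit:

import tensorflow as tf

def default_var_collection():
    # Prefer the newer collection key; fall back on older TensorFlow
    # releases where only tf.GraphKeys.VARIABLES exists.
    try:
        return tf.GraphKeys.GLOBAL_VARIABLES
    except AttributeError:
        return tf.GraphKeys.VARIABLES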

@@ -44,7 +44,12 @@ class GraphVarParam(HyperParam):
         self._readable_name, self.var_name = get_op_var_name(name)
 
     def setup_graph(self):
-        all_vars = tf.all_variables()
+        try:
+            all_vars = tf.global_variables()
+        except:
+            # TODO
+            all_vars = tf.all_variables()
         for v in all_vars:
             if v.name == self.var_name:
                 self.var = v

@@ -60,19 +60,22 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
     if use_local_stat:
         # training tower
         if ctx.is_training:
-            reuse = tf.get_variable_scope().reuse
-            with tf.name_scope(None):  # https://github.com/tensorflow/tensorflow/issues/2740
-                # TODO if reuse=True, try to find and use the existing statistics
-                # how to use multiple tensors to update one EMA? seems impossbile
-                ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
-                ema_apply_op = ema.apply([batch_mean, batch_var])
-                ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
-                if ctx.is_main_training_tower:
-                    # inside main training tower
-                    add_model_variable(ema_mean)
-                    add_model_variable(ema_var)
+            #reuse = tf.get_variable_scope().reuse
+            with tf.variable_scope(tf.get_variable_scope(), reuse=False):
+                # BatchNorm in reuse scope can be tricky! Moving mean/variance are not reused
+                with tf.name_scope(None):  # https://github.com/tensorflow/tensorflow/issues/2740
+                    # TODO if reuse=True, try to find and use the existing statistics
+                    # how to use multiple tensors to update one EMA? seems impossbile
+                    ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
+                    ema_apply_op = ema.apply([batch_mean, batch_var])
+                    ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
+                    if ctx.is_main_training_tower:
+                        # inside main training tower
+                        add_model_variable(ema_mean)
+                        add_model_variable(ema_var)
     else:
-        # no apply() is called here, no magic vars will get created
+        # no apply() is called here, no magic vars will get created,
+        # no reuse issue will happen
         assert not ctx.is_training
         with tf.name_scope(None):
             ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
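
For context on the tf.name_scope(None) trick kept in both branches: passing None clears the current op-name prefix, so the EMA update ops are not nested under each tower's name scope (the workaround referenced from TF issue #2740). A minimal sketch of the effect, with illustrative names only:

import tensorflow as tf

with tf.name_scope('tower0'):
    x = tf.constant(1.0, name='x')        # op name: 'tower0/x'
    with tf.name_scope(None):             # reset to the root name scope
        y = tf.constant(1.0, name='y')    # op name: 'y' (no tower prefix)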

@@ -81,14 +84,16 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
         sc = tf.get_variable_scope()
         if ctx.is_main_tower:
             # main tower, but needs to use global stat. global stat must be from outside
-            # TODO when reuse=True, the variable name could actually be different
+            # TODO when reuse=True, the desired variable name could
+            # actually be different, because a different var is created
+            # for different reuse tower
             ema_mean = tf.get_variable('mean/' + emaname, [n_out])
             ema_var = tf.get_variable('variance/' + emaname, [n_out])
         else:
             ## use statistics in another tower
             G = tf.get_default_graph()
-            ema_mean = ctx.find_tensor_in_main_tower(G, mean_var_name)
-            ema_var = ctx.find_tensor_in_main_tower(G, var_var_name)
+            ema_mean = ctx.find_tensor_in_main_tower(G, mean_var_name + ':0')
+            ema_var = ctx.find_tensor_in_main_tower(G, var_var_name + ':0')
 
     if use_local_stat:
         batch = tf.cast(tf.shape(x)[0], tf.float32)
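
The appended ':0' is needed because a graph tensor is addressed as op_name:output_index, while an op is addressed by its bare name; presumably find_tensor_in_main_tower ends up in a get_tensor_by_name-style lookup. A small sketch of the distinction, with illustrative names:

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    tf.constant(0.0, name='mean/EMA')

op = g.get_operation_by_name('mean/EMA')   # ops: bare name, no ':0'
t = g.get_tensor_by_name('mean/EMA:0')     # tensors: name must include the output index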

@@ -111,7 +111,11 @@ class Trainer(object):
         logger.info("Initializing graph variables ...")
         # TODO newsession + sessinit?
-        self.sess.run(tf.initialize_all_variables())
+        try:
+            initop = tf.global_variables_initializer()
+        except:
+            initop = tf.initialize_all_variables()
+        self.sess.run(initop)
         self.config.session_init.init(self.sess)
         tf.get_default_graph().finalize()
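
The initializer gets the same version shim as the variable collections earlier in the commit. A hedged sketch of how the repeated fallback could be kept in one place; get_init_op is a hypothetical helper, not part of tensorpack:

import tensorflow as tf

def get_init_op():
    # Use the renamed initializer when available, otherwise
    # fall back to the pre-rename API on older TensorFlow.
    try:
        return tf.global_variables_initializer()
    except AttributeError:
        return tf.initialize_all_variables()

# usage, mirroring Trainer: run it before session_init
# sess.run(get_init_op())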