Commit 6e3e0115 authored by Yuxin Wu

fix bn under latest TF

parent 5cccf2b8
@@ -45,7 +45,6 @@ class PennTreeBank(RNGDataFlow):
         super(PennTreeBank, self).__init__()
         if data_dir is None:
             data_dir = get_dataset_path('ptb_data')
-        assert os.path.isdir(data_dir)
         data3, word_to_id = get_raw_data(data_dir)
         self.word_to_id = word_to_id
         self.data = np.asarray(
...
@@ -40,7 +40,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
     beta = tf.get_variable('beta', [n_out],
                            initializer=tf.zeros_initializer)
     gamma = tf.get_variable('gamma', [n_out],
-                            initializer=tf.ones_initializer)
+                            initializer=tf.constant_initializer(1.0))
     if len(shape) == 2:
         batch_mean, batch_var = tf.nn.moments(x, [0], keep_dims=False)
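The gamma initializer swap is the "latest TF" compatibility fix named in the commit message: newer TF releases changed what `tf.ones_initializer` refers to, so it can no longer be passed uncalled as the `initializer=` argument, whereas `tf.constant_initializer(1.0)` returns an initializer callable on any version. A minimal sketch of the resulting pattern, assuming the TF 0.x-era API used in this repo:

    import tensorflow as tf

    n_out = 64  # number of channels, as in the diff
    with tf.variable_scope('bn_sketch'):
        # offset beta starts at 0, scale gamma starts at 1
        beta = tf.get_variable('beta', [n_out],
                               initializer=tf.constant_initializer(0.0))
        gamma = tf.get_variable('gamma', [n_out],
                                initializer=tf.constant_initializer(1.0))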
@@ -59,7 +59,11 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
     if use_local_stat:
         # training tower
+        if ctx.is_training:
+            reuse = tf.get_variable_scope().reuse
         with tf.name_scope(None):  # https://github.com/tensorflow/tensorflow/issues/2740
+            # TODO if reuse=True, try to find and use the existing statistics
+            # how to use multiple tensors to update one EMA? seems impossible
             ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
             ema_apply_op = ema.apply([batch_mean, batch_var])
             ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
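For context, the EMA bookkeeping above is the standard `tf.train.ExponentialMovingAverage` pattern: `apply()` both creates the shadow variables and returns the update op, and `average()` hands back the shadow variable for each tracked tensor. The `tf.name_scope(None)` wrapper escapes the tower's name scope so the shadow variables get stable names (the linked TF issue #2740). A standalone sketch, again assuming the TF 0.x-era API:

    import tensorflow as tf

    x = tf.random_normal([32, 64])                 # stand-in for a layer output
    batch_mean, batch_var = tf.nn.moments(x, [0])  # per-channel statistics

    with tf.name_scope(None):  # escape the enclosing name scope for stable EMA names
        ema = tf.train.ExponentialMovingAverage(decay=0.9, name='EMA')
        ema_apply_op = ema.apply([batch_mean, batch_var])  # creates shadow vars + update op
    ema_mean = ema.average(batch_mean)  # shadow variable tracking the moving mean
    ema_var = ema.average(batch_var)    # shadow variable tracking the moving variance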
@@ -68,30 +72,31 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
             add_model_variable(ema_mean)
             add_model_variable(ema_var)
     else:
-        if ctx.is_main_tower:
-            # not training, but main tower. need to create the vars
-            with tf.name_scope(None):
-                ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
-                ema_apply_op = ema.apply([batch_mean, batch_var])
-                ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
-        else:
-            # use statistics in another tower
-            G = tf.get_default_graph()
-            # figure out the var name
-            with tf.name_scope(None):
-                ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
-                mean_var_name = ema.average_name(batch_mean) + ':0'
-                var_var_name = ema.average_name(batch_var) + ':0'
-            ema_mean = ctx.find_tensor_in_main_tower(G, mean_var_name)
-            ema_var = ctx.find_tensor_in_main_tower(G, var_var_name)
-            #logger.info("In prediction, using {} instead of {} for {}".format(
-                #mean_name, ema_mean.name, batch_mean.name))
+        # no apply() is called here, no magic vars will get created
+        assert not ctx.is_training
+        with tf.name_scope(None):
+            ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
+            mean_var_name = ema.average_name(batch_mean)
+            var_var_name = ema.average_name(batch_var)
+        sc = tf.get_variable_scope()
+        if ctx.is_main_tower:
+            # main tower, but needs to use global stat. global stat must be from outside
+            # TODO when reuse=True, the variable name could actually be different
+            ema_mean = tf.get_variable('mean/' + emaname, [n_out])
+            ema_var = tf.get_variable('variance/' + emaname, [n_out])
+        else:
+            ## use statistics in another tower
+            G = tf.get_default_graph()
+            ema_mean = ctx.find_tensor_in_main_tower(G, mean_var_name)
+            ema_var = ctx.find_tensor_in_main_tower(G, var_var_name)
     if use_local_stat:
-        with tf.control_dependencies([ema_apply_op]):
         batch = tf.cast(tf.shape(x)[0], tf.float32)
         mul = tf.select(tf.equal(batch, 1.0), 1.0, batch / (batch - 1))
         batch_var = batch_var * mul  # use unbiased variance estimator in training
+        with tf.control_dependencies([ema_apply_op] if ctx.is_training else []):
+            # only apply EMA op if is_training
             return tf.nn.batch_normalization(
                 x, batch_mean, batch_var, beta, gamma, epsilon, 'output')
     else:
...
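Two techniques carry this hunk. First, the non-training path now calls `ema.average_name()` instead of `ema.apply()`: `average_name()` only derives the name the shadow variable would have, so the predict graph no longer creates stray ("magic") EMA variables, and the statistics are fetched by name instead. Second, the EMA update is attached as a control dependency only when `ctx.is_training` (an empty dependency list is a no-op); the `mul = n / (n - 1)` factor is the usual Bessel correction turning the biased moment estimate into an unbiased variance. A rough standalone sketch of the name-based lookup, using `get_tensor_by_name` as an assumed stand-in for tensorpack's `ctx.find_tensor_in_main_tower` helper (not shown in this diff):

    import tensorflow as tf

    x = tf.random_normal([32, 64])
    batch_mean, batch_var = tf.nn.moments(x, [0])

    # Training graph: apply() creates the shadow variables and the update op.
    ema = tf.train.ExponentialMovingAverage(decay=0.9, name='EMA')
    ema_apply_op = ema.apply([batch_mean, batch_var])

    # Inference graph: average_name() derives the same names WITHOUT creating variables.
    mean_var_name = ema.average_name(batch_mean)  # shadow var name, from the tensor's op name
    var_var_name = ema.average_name(batch_var)
    G = tf.get_default_graph()
    ema_mean = G.get_tensor_by_name(mean_var_name + ':0')
    ema_var = G.get_tensor_by_name(var_var_name + ':0')

    # EMA update runs only while training; with an empty list this is a no-op.
    is_training = True  # assumed flag standing in for ctx.is_training
    with tf.control_dependencies([ema_apply_op] if is_training else []):
        normed = tf.nn.batch_normalization(
            x, batch_mean, batch_var, None, None, 1e-5)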
@@ -110,6 +110,7 @@ class Trainer(object):
         self.stat_holder = StatHolder(logger.LOG_DIR)
         logger.info("Initializing graph variables ...")
+        # TODO newsession + sessinit?
         self.sess.run(tf.initialize_all_variables())
         self.config.session_init.init(self.sess)
...
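The TODO marks a real ordering question: `tf.initialize_all_variables()` (the TF 0.x spelling, later renamed `tf.global_variables_initializer`) runs first so every variable has a defined value, and the configured `session_init` then overwrites whatever it knows how to restore. A minimal sketch of that order, with a plain `Saver` standing in for tensorpack's SessionInit object and a placeholder checkpoint path:

    import tensorflow as tf

    w = tf.get_variable('w', [10])           # some model variable
    saver = tf.train.Saver()

    sess = tf.Session()
    sess.run(tf.initialize_all_variables())  # 1. every variable gets a default value
    # 2. the session_init step (a checkpoint restore, say) then overwrites it;
    #    the path below is a placeholder for illustration only
    saver.restore(sess, '/path/to/model.ckpt')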