Commit 7a110067 authored by Yuxin Wu

update bn to use towercontext

parent 7a0e8747
@@ -18,7 +18,7 @@ __all__ = ['BatchNorm']
 # decay: being too close to 1 leads to slow start-up. torch use 0.9.
 # eps: torch: 1e-5. Lasagne: 1e-4
 @layer_register(log_shape=False)
-def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
+def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
     """
     Batch normalization layer as described in:
@@ -30,8 +30,9 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
     * Whole-population mean/variance is calculated by a running-average mean/variance.
     * Epsilon for variance is set to 1e-5, as is `torch/nn <https://github.com/torch/nn/blob/master/BatchNormalization.lua>`_.

-    :param input: a NHWC tensor or a NC vector
-    :param use_local_stat: bool. whether to use mean/var of this batch or the moving average. Set to True in training and False in testing
+    :param input: a NHWC or NC tensor
+    :param use_local_stat: bool. whether to use mean/var of this batch or the moving average.
+        Defaults to True in training and False in prediction.
     :param decay: decay rate. default to 0.9.
     :param epsilon: default to 1e-5.
     """
@@ -53,41 +54,34 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
     batch_mean = tf.identity(batch_mean, 'mean')
     batch_var = tf.identity(batch_var, 'variance')

     # XXX a hack to handle training tower & prediction tower together....
     emaname = 'EMA'
-    #ctx = get_current_model_context()
-    if not batch_mean.name.startswith('towerp'):
+    ctx = get_current_model_context()
+    if use_local_stat is None:
+        use_local_stat = ctx.is_training
+    assert use_local_stat == ctx.is_training
+    if ctx.is_training:
         # training tower
         with tf.name_scope(None):   # https://github.com/tensorflow/tensorflow/issues/2740
             ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
             ema_apply_op = ema.apply([batch_mean, batch_var])
             ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
-            if not batch_mean.name.startswith('tower') or \
-                    batch_mean.name.startswith('tower0'):
+            if ctx.is_main_training_tower:
                 # inside main training tower
                 tf.add_to_collection(EXTRA_SAVE_VARS_KEY, ema_mean)
                 tf.add_to_collection(EXTRA_SAVE_VARS_KEY, ema_var)
     else:
         # use training-statistics in prediction
         assert not use_local_stat
         with tf.name_scope(None):
             # figure out the var name
             ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
             mean_var_name = ema.average_name(batch_mean) + ':0'
             var_var_name = ema.average_name(batch_var) + ':0'
             # use statistics in another tower
             G = tf.get_default_graph()
             # find training statistics in training tower
-            try:
-                mean_name = re.sub('towerp[0-9]+/', '', mean_var_name)
-                var_name = re.sub('towerp[0-9]+/', '', var_var_name)
-                #var_name = batch_var.op.name[prefixlen:] + '/' + emaname + ':0'
-                ema_mean = G.get_tensor_by_name(mean_name)
-                ema_var = G.get_tensor_by_name(var_name)
-            except KeyError:
-                mean_name = re.sub('towerp[0-9]+/', 'tower0/', mean_var_name)
-                var_name = re.sub('towerp[0-9]+/', 'tower0/', var_var_name)
-                ema_mean = G.get_tensor_by_name(mean_name)
-                ema_var = G.get_tensor_by_name(var_name)
+            ema_mean = ctx.find_tensor_in_main_tower(G, mean_var_name)
+            ema_var = ctx.find_tensor_in_main_tower(G, var_var_name)
             #logger.info("In prediction, using {} instead of {} for {}".format(
             #    mean_name, ema_mean.name, batch_mean.name))
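For readers unfamiliar with the EMA plumbing above: tf.train.ExponentialMovingAverage.apply() creates one shadow variable per tracked tensor plus an update op, average() returns that shadow variable, and average_name() returns its name, which is exactly what the prediction branch resolves through the graph. The tf.name_scope(None) wrapper appears to clear the current name scope so the shadow variables get consistent names across towers (the linked TF issue). A standalone TF 1.x sketch; the tensor names are illustrative:

    import tensorflow as tf

    x = tf.placeholder(tf.float32, [None, 64])
    batch_mean, batch_var = tf.nn.moments(x, [0])
    batch_mean = tf.identity(batch_mean, 'mean')
    batch_var = tf.identity(batch_var, 'variance')

    ema = tf.train.ExponentialMovingAverage(decay=0.9, name='EMA')
    ema_apply_op = ema.apply([batch_mean, batch_var])  # shadow vars + update op
    ema_mean = ema.average(batch_mean)                 # shadow variable holding the average
    print(ema.average_name(batch_mean) + ':0')         # e.g. 'mean/EMA:0', as looked up above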
@@ -26,6 +26,10 @@ class TowerContext(object):
             is_training = not self._name.startswith('towerp')
         self._is_training = is_training

+    @property
+    def is_main_training_tower(self):
+        return self.is_training and (self._name == '' or self._name == 'tower0')
+
     @property
     def is_main_tower(self):
         return self._name == '' or self._name == 'tower0'
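These properties encode the tower-naming convention used throughout: an empty name or 'tower0' is the main tower, 'tower1', 'tower2', ... are extra training towers, and a 'towerp' prefix marks prediction towers. A plain-Python restatement of the string logic, just to make the cases explicit:

    def describe(tower_name):
        is_training = not tower_name.startswith('towerp')
        is_main_tower = tower_name in ('', 'tower0')
        return is_training, is_main_tower, is_training and is_main_tower

    # name       -> (is_training, is_main_tower, is_main_training_tower)
    for name in ['', 'tower0', 'tower1', 'towerp0']:
        print(repr(name), describe(name))
    # ''        -> (True, True, True)
    # 'tower0'  -> (True, True, True)
    # 'tower1'  -> (True, False, False)
    # 'towerp0' -> (False, False, False)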
@@ -34,6 +38,17 @@ class TowerContext(object):
     def is_training(self):
         return self._is_training

+    def find_tensor_in_main_tower(self, graph, name):
+        if self.is_main_tower:
+            return graph.get_tensor_by_name(name)
+        if name.startswith('towerp'):
+            newname = re.sub('towerp[0-9]+/', '', name)
+            try:
+                return graph.get_tensor_by_name(newname)
+            except KeyError:
+                newname = re.sub('towerp[0-9]+/', 'tower0/', name)
+                return graph.get_tensor_by_name(newname)
+
     def __enter__(self):
         global _CurrentTowerContext
         assert _CurrentTowerContext is None, \
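The fallback chain in find_tensor_in_main_tower mirrors where the EMA tensors can live: single-tower training creates them at the graph root, multi-tower training under 'tower0/', and a prediction tower's own 'towerpN/' prefix matches neither. A sketch of just the name rewriting, on a hypothetical tensor name:

    import re

    name = 'towerp0/conv1/bn/mean/EMA:0'              # hypothetical prediction-tower name
    print(re.sub('towerp[0-9]+/', '', name))          # 'conv1/bn/mean/EMA:0' (root, tried first)
    print(re.sub('towerp[0-9]+/', 'tower0/', name))   # 'tower0/conv1/bn/mean/EMA:0' (fallback)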