Commit b9e79e1c authored by Yuxin Wu's avatar Yuxin Wu

add VariableHolder for BN

parent 0e4ddfd6
...@@ -9,7 +9,7 @@ from tensorflow.python.training import moving_averages ...@@ -9,7 +9,7 @@ from tensorflow.python.training import moving_averages
from ..tfutils.tower import get_current_tower_context from ..tfutils.tower import get_current_tower_context
from ..utils import logger from ..utils import logger
from .common import layer_register from .common import layer_register, VariableHolder
__all__ = ['BatchNorm', 'BatchRenorm'] __all__ = ['BatchNorm', 'BatchRenorm']
...@@ -220,9 +220,16 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5, ...@@ -220,9 +220,16 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
# maintain EMA only on one GPU. # maintain EMA only on one GPU.
if ctx.is_main_training_tower: if ctx.is_main_training_tower:
return update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay) ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay)
else: else:
return tf.identity(xn, name='output') ret = tf.identity(xn, name='output')
vh = ret.variables = VariableHolder(mean=moving_mean, variance=moving_var)
if use_scale:
vh.gamma = gamma
if use_bias:
vh.beta = beta
return ret
# TODO support NCHW # TODO support NCHW
......
...@@ -31,7 +31,6 @@ class VariableHolder(object): ...@@ -31,7 +31,6 @@ class VariableHolder(object):
self._add_variable(k, v) self._add_variable(k, v)
def _add_variable(self, name, var): def _add_variable(self, name, var):
print(name, var.name)
assert name not in self._vars assert name not in self._vars
self._vars[name] = var self._vars[name] = var
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment