Commit 249052e0 authored by Yuxin Wu

add variable in batchrenorm. fix bug in CacheData

parent b9e79e1c
@@ -566,6 +566,7 @@ class CacheData(ProxyDataFlow):
     def get_data(self):
         if len(self.buffer):
-            self.rng.shuffle(self.buffer)
+            if self.shuffle:
+                self.rng.shuffle(self.buffer)
             for dp in self.buffer:
                 yield dp
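This hunk is the "fix bug in CacheData" half of the commit: re-shuffling the cached buffer is now guarded by the `shuffle` flag, so the cache is only reordered when the caller asked for it. A minimal, self-contained sketch of the fixed control flow (a hypothetical stand-in class, stripped of tensorpack's ProxyDataFlow machinery; `np.random.RandomState` stands in for the DataFlow rng):

    import numpy as np

    class CacheDataSketch:
        """Hypothetical illustration of CacheData's caching logic."""
        def __init__(self, ds, shuffle=False):
            self.ds = ds
            self.shuffle = shuffle
            self.buffer = []
            # The rng is only needed (and in this sketch, only created)
            # when shuffling, which is why the shuffle call must be guarded.
            if shuffle:
                self.rng = np.random.RandomState()

        def get_data(self):
            if len(self.buffer):
                if self.shuffle:  # the guard this commit adds
                    self.rng.shuffle(self.buffer)
                for dp in self.buffer:
                    yield dp
            else:
                # First pass: stream from the source while filling the cache.
                for dp in self.ds:
                    yield dp
                    self.buffer.append(dp)

    cached = CacheDataSketch(list(range(5)), shuffle=True)
    print(list(cached.get_data()))  # first epoch: source order, cache fills
    print(list(cached.get_data()))  # later epochs: shuffled from the cache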
@@ -232,7 +232,6 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
     return ret
 
-# TODO support NCHW
 @layer_register(log_shape=False)
 def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
                 use_scale=True, use_bias=True, data_format='NHWC'):
@@ -308,6 +307,13 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
             x, moving_mean, moving_var, beta, gamma, epsilon)
 
     if ctx.is_main_training_tower:
-        return update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay)
+        ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay)
     else:
-        return tf.identity(xn, name='output')
+        ret = tf.identity(xn, name='output')
+
+    vh = ret.variables = VariableHolder(mean=moving_mean, variance=moving_var)
+    if use_scale:
+        vh.gamma = gamma
+    if use_bias:
+        vh.beta = beta
+    return ret
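This hunk is the "add variable in batchrenorm" half: both branches now assign to `ret` instead of returning early, and the layer hangs its variables off the output before returning, so callers can reach the moving statistics (plus `gamma`/`beta` when `use_scale`/`use_bias` are on) as attributes of the result. A minimal, TensorFlow-free sketch of the attach pattern; this `VariableHolder` is a hypothetical stand-in for tensorpack's own:

    class VariableHolder(object):
        """Plain attribute bag, mimicking the role of tensorpack's VariableHolder."""
        def __init__(self, **kwargs):
            for name, value in kwargs.items():
                setattr(self, name, value)

    class FakeTensor(object):
        """Stand-in for the layer output; a plain object accepts new attributes."""
        pass

    def attach_variables(ret, moving_mean, moving_var, gamma=None, beta=None):
        # Same shape as the lines added in the diff above.
        vh = ret.variables = VariableHolder(mean=moving_mean, variance=moving_var)
        if gamma is not None:
            vh.gamma = gamma
        if beta is not None:
            vh.beta = beta
        return ret

    out = attach_variables(FakeTensor(), 'mean/EMA', 'variance/EMA', gamma='gamma:0')
    assert out.variables.mean == 'mean/EMA'
    assert out.variables.gamma == 'gamma:0'
    assert not hasattr(out.variables, 'beta')  # the use_bias=False case

The payoff is that code holding a BatchRenorm output can read, say, `output.variables.mean` directly instead of hunting for the EMA variables by name in the graph.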