Commit 249052e0 authored by Yuxin Wu

add variable in batchrenorm. fix bug in CacheData

parent b9e79e1c
@@ -566,7 +566,8 @@ class CacheData(ProxyDataFlow):
     def get_data(self):
         if len(self.buffer):
-            self.rng.shuffle(self.buffer)
+            if self.shuffle:
+                self.rng.shuffle(self.buffer)
             for dp in self.buffer:
                 yield dp
         else:
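The fix above makes `CacheData` respect its `shuffle` flag: previously the cached buffer was shuffled on every replay even when shuffling was not requested. Below is a minimal standalone sketch of the fixed logic, for illustration only; `CacheDataSketch` is a hypothetical stand-in, not tensorpack's actual class.

```python
import random


class CacheDataSketch:
    """Cache datapoints from a wrapped iterable on the first pass,
    then replay the cache, shuffled only when requested."""

    def __init__(self, ds, shuffle=False):
        self.ds = ds                # any iterable of datapoints
        self.shuffle = shuffle
        self.buffer = []
        self.rng = random.Random(0)

    def get_data(self):
        if len(self.buffer):
            if self.shuffle:        # the fix: shuffle only when asked to
                self.rng.shuffle(self.buffer)
            for dp in self.buffer:
                yield dp
        else:
            for dp in self.ds:      # first pass: yield and fill the cache
                yield dp
                self.buffer.append(dp)


cache = CacheDataSketch(range(5), shuffle=False)
print(list(cache.get_data()))  # first pass fills the cache: [0, 1, 2, 3, 4]
print(list(cache.get_data()))  # replay keeps order, since shuffle=False
```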
@@ -232,7 +232,6 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
     return ret


-# TODO support NCHW
 @layer_register(log_shape=False)
 def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
                 use_scale=True, use_bias=True, data_format='NHWC'):
@@ -308,6 +307,13 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
         x, moving_mean, moving_var, beta, gamma, epsilon)

     if ctx.is_main_training_tower:
-        return update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay)
+        ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay)
     else:
-        return tf.identity(xn, name='output')
+        ret = tf.identity(xn, name='output')
+
+    vh = ret.variables = VariableHolder(mean=moving_mean, variance=moving_var)
+    if use_scale:
+        vh.gamma = gamma
+    if use_bias:
+        vh.beta = beta
+    return ret
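The `BatchRenorm` change stops returning the output tensor directly and instead attaches a `VariableHolder` to it, so callers can reach the layer's variables (`mean`, `variance`, and `gamma`/`beta` when enabled) from the returned value. A hedged pure-Python sketch of that pattern follows; `FakeTensor` and the string placeholders are stand-ins for the real TensorFlow objects, not tensorpack's implementation.

```python
class VariableHolder:
    """Simplified stand-in: a record of the variables a layer created."""

    def __init__(self, **kwargs):
        for name, var in kwargs.items():
            setattr(self, name, var)


class FakeTensor:
    """Stand-in for the layer's output tensor (attributes can be attached)."""


def batch_renorm_sketch(use_scale=True, use_bias=True):
    # Placeholders for the variables the real layer would create.
    moving_mean, moving_var = 'mean/EMA', 'variance/EMA'
    gamma, beta = 'gamma', 'beta'

    ret = FakeTensor()
    # The commit's pattern: hang a VariableHolder off the returned tensor.
    vh = ret.variables = VariableHolder(mean=moving_mean, variance=moving_var)
    if use_scale:
        vh.gamma = gamma
    if use_bias:
        vh.beta = beta
    return ret


out = batch_renorm_sketch(use_scale=True, use_bias=False)
print(out.variables.mean, out.variables.gamma)  # variables reachable from output
print(hasattr(out.variables, 'beta'))           # False: beta only set if use_bias
```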