Commit 02e94b63 authored by Yuxin Wu

gamma_init in batch_norm

parent fd635774
@@ -39,7 +39,7 @@ if args.output:
         for bi, img in enumerate(imgbatch):
             cnt += 1
             fname = os.path.join(args.output, '{:03d}-{}.png'.format(cnt, bi))
-            cv2.imwrite(fname, img)
+            cv2.imwrite(fname, img * 255)
     NR_DP_TEST = 100
     logger.info("Testing dataflow speed:")
...
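The `img * 255` change suggests the dataflow yields float images in the [0, 1] range; `cv2.imwrite` expects 8-bit pixel values, so writing the raw floats would produce a nearly black PNG. A minimal sketch of that assumption (the array here is synthetic, not the project's dataflow):

import cv2
import numpy as np

# synthetic float image in [0, 1], standing in for what the dataflow yields
img = np.random.rand(32, 32, 3).astype('float32')

# scale to the 0-255 range before writing, as the diff above does;
# writing the raw [0, 1] floats would come out almost entirely black
cv2.imwrite('/tmp/sample.png', (img * 255).astype('uint8'))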
@@ -13,7 +13,7 @@ __all__ = ['BatchNorm']
 # http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow
 # TF batch_norm only works for 4D tensor right now: #804
 @layer_register()
-def BatchNorm(x, is_training):
+def BatchNorm(x, is_training, gamma_init=1.0):
     """
     Batch normalization layer as described in:
     Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
@@ -35,7 +35,8 @@ def BatchNorm(x, is_training):
     n_out = shape[-1]  # channel
     beta = tf.get_variable('beta', [n_out])
-    gamma = tf.get_variable('gamma', [n_out], initializer=tf.constant_initializer(1.0))
+    gamma = tf.get_variable('gamma', [n_out],
+                            initializer=tf.constant_initializer(gamma_init))
     batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')
     ema = tf.train.ExponentialMovingAverage(decay=0.999)
@@ -49,6 +50,7 @@ def BatchNorm(x, is_training):
     batch = tf.cast(tf.shape(x)[0], tf.float32)
     mean, var = ema_mean, ema_var * batch / (batch - 1)  # unbiased variance estimator
-    normed = tf.nn.batch_norm_with_global_normalization(x, mean, var, beta, gamma, EPS, True)
+    normed = tf.nn.batch_norm_with_global_normalization(
+        x, mean, var, beta, gamma, EPS, True)
     return normed
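With the new keyword, the initial value of the learned scale `gamma` can be chosen per layer instead of always starting at 1.0. A hedged usage sketch, assuming the name-first calling convention that `@layer_register()` provides (the surrounding model code and layer names are hypothetical):

# inside a model-building function; 'bn1'/'bn2' are illustrative layer names
l = BatchNorm('bn1', l, is_training)                   # default: gamma starts at 1.0
l = BatchNorm('bn2', l, is_training, gamma_init=0.1)   # start with a smaller scale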
@@ -18,6 +18,7 @@ def FullyConnected(x, out_dim, W_init=None, b_init=None, nl=tf.nn.relu):
     if W_init is None:
         W_init = tf.truncated_normal_initializer(stddev=1 / math.sqrt(float(in_dim)))
+        #W_init = tf.uniform_unit_scaling_initializer()
     if b_init is None:
         b_init = tf.constant_initializer(0.0)
...
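For reference, the default `W_init` above scales a truncated normal by 1/sqrt(in_dim) (fan-in scaling), which keeps the variance of the pre-activations roughly independent of layer width; the commented-out `uniform_unit_scaling_initializer` pursues a similar goal. A small NumPy sketch of the effect (shapes are illustrative):

import math
import numpy as np

in_dim, out_dim = 1024, 256
# fan-in scaled Gaussian, matching stddev = 1 / sqrt(in_dim) above
W = np.random.randn(in_dim, out_dim) / math.sqrt(in_dim)

x = np.random.randn(128, in_dim)   # unit-variance activations
print(np.var(x @ W))               # pre-activation variance stays close to 1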