Commit 29ad5d2a authored by Yuxin Wu

replace zeros_initializer for compatibility

parent 99c70935
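
For context, a minimal sketch (not part of the commit) of the API change this works around, assuming TensorFlow 1.x: tf.zeros_initializer turned from a directly usable initializer into a class that must be instantiated, whereas tf.constant_initializer() defaults to zero and already returns an initializer instance, so it behaves the same on both sides of the change.

import tensorflow as tf

# TF < 1.0 accepted the bare name, because zeros_initializer was a plain
# function with the initializer call signature:
#     tf.get_variable('beta', [4], initializer=tf.zeros_initializer)
# TF >= 1.0 makes tf.zeros_initializer a class, which would need a call:
#     tf.get_variable('beta', [4], initializer=tf.zeros_initializer())
# tf.constant_initializer() (default value 0) is an instance in every
# version, hence the spelling adopted by this commit:
beta = tf.get_variable('beta', [4], initializer=tf.constant_initializer())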
@@ -27,8 +27,8 @@ class Model(ModelDesc):
         with tf.variable_scope(name) as scope:
             l = Conv2D('convfc', l, 1, kernel_shape=1, nl=tf.identity,
                        use_bias=True,
-                       W_init=tf.zeros_initializer,
-                       b_init=tf.zeros_initializer)
+                       W_init=tf.constant_initializer(),
+                       b_init=tf.constant_initializer())
         while up != 1:
             l = BilinearUpSample('upsample{}'.format(up), l, 2)
             up = up / 2
...
@@ -39,7 +39,7 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
     n_out = shape[-1]  # channel
     assert n_out is not None
     beta = tf.get_variable('beta', [n_out],
-                           initializer=tf.zeros_initializer)
+                           initializer=tf.constant_initializer())
     gamma = tf.get_variable('gamma', [n_out],
                             initializer=tf.constant_initializer(1.0))
@@ -131,7 +131,7 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
     x = tf.reshape(x, [-1, 1, 1, n_out])
     beta = tf.get_variable('beta', [n_out],
-                           initializer=tf.zeros_initializer)
+                           initializer=tf.constant_initializer())
     gamma = tf.get_variable('gamma', [n_out],
                             initializer=tf.constant_initializer(1.0))
     # x * gamma + beta
@@ -143,9 +143,9 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
         logger.warn("[BatchNorm] use_local_stat != is_training")
     moving_mean = tf.get_variable('mean/EMA', [n_out],
-                                  initializer=tf.zeros_initializer, trainable=False)
+                                  initializer=tf.constant_initializer(), trainable=False)
     moving_var = tf.get_variable('variance/EMA', [n_out],
-                                 initializer=tf.zeros_initializer, trainable=False)
+                                 initializer=tf.constant_initializer(), trainable=False)
     if use_local_stat:
         xn, batch_mean, batch_var = tf.nn.fused_batch_norm(x, gamma, beta,
...
@@ -48,7 +48,7 @@ def get_global_step_var():
         "Creating global_step_var under a variable scope would cause problems!"
     with tf.variable_scope(scope, reuse=False):
         var = tf.get_variable(GLOBAL_STEP_OP_NAME, shape=[],
-                              initializer=tf.zeros_initializer,
+                              initializer=tf.constant_initializer(dtype=tf.int32),
                               trainable=False, dtype=tf.int32)
     return var
...
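
Note on the last hunk: the global-step variable is created with dtype=tf.int32, so the replacement initializer is also given dtype=tf.int32 to keep the zero initial value typed consistently; the other call sites keep tf.constant_initializer's float default.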