Commit 29ad5d2a authored by Yuxin Wu's avatar Yuxin Wu

replace zeros_initializer for compatibility

parent 99c70935
......@@ -27,8 +27,8 @@ class Model(ModelDesc):
with tf.variable_scope(name) as scope:
l = Conv2D('convfc', l, 1, kernel_shape=1, nl=tf.identity,
use_bias=True,
W_init=tf.zeros_initializer,
b_init=tf.zeros_initializer)
W_init=tf.constant_initializer(),
b_init=tf.constant_initializer())
while up != 1:
l = BilinearUpSample('upsample{}'.format(up), l, 2)
up = up / 2
......
......@@ -39,7 +39,7 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
n_out = shape[-1] # channel
assert n_out is not None
beta = tf.get_variable('beta', [n_out],
initializer=tf.zeros_initializer)
initializer=tf.constant_initializer())
gamma = tf.get_variable('gamma', [n_out],
initializer=tf.constant_initializer(1.0))
......@@ -131,7 +131,7 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
x = tf.reshape(x, [-1, 1, 1, n_out])
beta = tf.get_variable('beta', [n_out],
initializer=tf.zeros_initializer)
initializer=tf.constant_initializer())
gamma = tf.get_variable('gamma', [n_out],
initializer=tf.constant_initializer(1.0))
# x * gamma + beta
......@@ -143,9 +143,9 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
logger.warn("[BatchNorm] use_local_stat != is_training")
moving_mean = tf.get_variable('mean/EMA', [n_out],
initializer=tf.zeros_initializer, trainable=False)
initializer=tf.constant_initializer(), trainable=False)
moving_var = tf.get_variable('variance/EMA', [n_out],
initializer=tf.zeros_initializer, trainable=False)
initializer=tf.constant_initializer(), trainable=False)
if use_local_stat:
xn, batch_mean, batch_var = tf.nn.fused_batch_norm(x, gamma, beta,
......
......@@ -48,7 +48,7 @@ def get_global_step_var():
"Creating global_step_var under a variable scope would cause problems!"
with tf.variable_scope(scope, reuse=False):
var = tf.get_variable(GLOBAL_STEP_OP_NAME, shape=[],
initializer=tf.zeros_initializer,
initializer=tf.constant_initializer(dtype=tf.int32),
trainable=False, dtype=tf.int32)
return var
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment