Commit 13e0ec39 authored by Yuxin Wu's avatar Yuxin Wu

Use NHWC in batch_norm when shape is 2d. fix #190

parent 8f917c01
......@@ -167,13 +167,13 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
shape = x.get_shape().as_list()
assert len(shape) in [2, 4]
if len(shape) == 2:
data_format = 'NCHW'
data_format = 'NHWC' # error using NCHW? (see #190)
if data_format == 'NCHW':
n_out = shape[1]
else:
n_out = shape[-1] # channel
if len(shape) == 2:
x = tf.reshape(x, [-1, n_out, 1, 1])
x = tf.reshape(x, [-1, 1, 1, n_out])
assert n_out is not None, "Input to BatchNorm cannot have unknown channels!"
beta, gamma, moving_mean, moving_var = get_bn_variables(n_out, use_scale, use_bias)
......@@ -213,6 +213,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
return tf.identity(xn, name='output')
# TODO support NCHW
@layer_register(log_shape=False)
def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
use_scale=True, use_bias=True):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment