Commit c759f211 authored by Yuxin Wu

use official batchnorm op

parent df89a95f
......@@ -19,8 +19,7 @@ from tensorpack.dataflow import *
from tensorpack.dataflow import imgaug
"""
CIFAR10 90% validation accuracy after 100k steps.
91% after 160k steps.
CIFAR10 90% validation accuracy after 70k steps.
"""
BATCH_SIZE = 128
......@@ -43,6 +42,7 @@ class Model(ModelDesc):
num_threads=6, enqueue_many=True)
tf.image_summary("train_image", image, 10)
image = image / 4.0 # just to make range smaller
l = Conv2D('conv1.1', image, out_channel=64, kernel_shape=3)
l = Conv2D('conv1.2', l, out_channel=64, kernel_shape=3, nl=tf.identity)
l = BatchNorm('bn1', l, is_training)
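Note the ordering in this hunk: 'conv1.2' is built with nl=tf.identity so that BatchNorm normalizes the raw pre-activation output, with the nonlinearity (not visible in the hunk) applied afterwards. A minimal sketch of the pattern, assuming a plain ReLU follows the BN layer:

    # pre-activation batch norm: Conv2D with no nonlinearity -> BatchNorm -> ReLU
    l = Conv2D('conv1.2', l, out_channel=64, kernel_shape=3, nl=tf.identity)
    l = BatchNorm('bn1', l, is_training)
    l = tf.nn.relu(l)   # assumed follow-up activation, elided from the hunk above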
......@@ -112,7 +112,7 @@ def get_data(train_or_test):
ds = AugmentImageComponent(ds, augmentors)
ds = BatchData(ds, 128, remainder=not isTrain)
if isTrain:
ds = PrefetchData(ds, 3, 2)
ds = PrefetchData(ds, 10, 5)
return ds
......@@ -120,7 +120,7 @@ def get_data(train_or_test):
def get_config():
# prepare dataset
dataset_train = get_data('train')
step_per_epoch = dataset_train.size() / 2
step_per_epoch = dataset_train.size()
dataset_test = get_data('test')
sess_config = get_default_sess_config()
......
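For scale: CIFAR10 has 50,000 training images and get_data batches them by 128 with remainder=False, so dataset_train.size() should report the number of full batches; dropping the division by 2 therefore roughly doubles the steps counted per epoch. A quick back-of-the-envelope check, assuming BatchData.size() is floor division when the incomplete last batch is dropped:

    # assumes BatchData drops the incomplete last batch when remainder=False
    num_images, batch_size = 50000, 128
    step_per_epoch = num_images // batch_size   # 390 steps per epoch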
......@@ -31,10 +31,7 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
"""
shape = x.get_shape().as_list()
if len(shape) == 2:
x = tf.reshape(x, [-1, 1, 1, shape[1]])
shape = x.get_shape().as_list()
assert len(shape) == 4
assert len(shape) in [2, 4]
n_out = shape[-1] # channel
beta = tf.get_variable('beta', [n_out])
......@@ -42,7 +39,10 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
'gamma', [n_out],
initializer=tf.constant_initializer(1.0))
batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')
if len(shape) == 2:
batch_mean, batch_var = tf.nn.moments(x, [0], name='moments', keep_dims=False)
else:
batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments', keep_dims=False)
ema = tf.train.ExponentialMovingAverage(decay=decay)
ema_apply_op = ema.apply([batch_mean, batch_var])
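The branch added above keeps the statistics per-channel in both cases: for 4-D NHWC activations the moments are reduced over batch, height and width, and for 2-D fully-connected outputs only over the batch axis. A minimal standalone sketch of the axis choice (the placeholder shapes are illustrative, not from this commit):

    import tensorflow as tf

    # 4-D NHWC feature map: reduce over N, H, W -> one mean/var per channel
    x4 = tf.placeholder(tf.float32, [None, 32, 32, 64])
    mean4, var4 = tf.nn.moments(x4, [0, 1, 2])   # both shaped [64]

    # 2-D fully-connected output: reduce over the batch axis only
    x2 = tf.placeholder(tf.float32, [None, 64])
    mean2, var2 = tf.nn.moments(x2, [0])         # both shaped [64]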
......@@ -50,10 +50,10 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
if use_local_stat:
with tf.control_dependencies([ema_apply_op]):
return tf.nn.batch_norm_with_global_normalization(
x, batch_mean, batch_var, beta, gamma, epsilon, True)
return tf.nn.batch_normalization(
x, batch_mean, batch_var, beta, gamma, epsilon, 'bn')
else:
batch = tf.cast(tf.shape(x)[0], tf.float32)
mean, var = ema_mean, ema_var * batch / (batch - 1) # unbiased variance estimator
return tf.nn.batch_norm_with_global_normalization(
x, mean, var, beta, gamma, epsilon, True)
return tf.nn.batch_normalization(
x, mean, var, beta, gamma, epsilon, 'bn')
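Putting the pieces of this file together, the revised layer roughly reads as below. This is a sketch, not the exact source: the ema_mean/ema_var definitions and the variable-scope/layer_register plumbing are elided from the hunks above and are filled in here as assumptions.

    import tensorflow as tf

    def batch_norm_sketch(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
        shape = x.get_shape().as_list()
        assert len(shape) in [2, 4]
        n_out = shape[-1]   # channel dimension
        beta = tf.get_variable('beta', [n_out])
        gamma = tf.get_variable('gamma', [n_out],
                                initializer=tf.constant_initializer(1.0))

        # per-channel statistics of the current batch
        if len(shape) == 2:
            batch_mean, batch_var = tf.nn.moments(x, [0], name='moments')
        else:
            batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')

        # running averages for inference; the ema_mean/ema_var lines are not
        # visible in the hunk above and are assumed here
        ema = tf.train.ExponentialMovingAverage(decay=decay)
        ema_apply_op = ema.apply([batch_mean, batch_var])
        ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)

        if use_local_stat:
            # training: normalize with batch statistics, updating the EMA as a side effect
            with tf.control_dependencies([ema_apply_op]):
                return tf.nn.batch_normalization(
                    x, batch_mean, batch_var, beta, gamma, epsilon, 'bn')
        else:
            # inference: use the moving averages, with Bessel's correction on the variance
            batch = tf.cast(tf.shape(x)[0], tf.float32)
            mean, var = ema_mean, ema_var * batch / (batch - 1)
            return tf.nn.batch_normalization(
                x, mean, var, beta, gamma, epsilon, 'bn')

The change named in the commit message is the pair of return statements: tf.nn.batch_normalization, which takes the offset, scale and epsilon directly, replaces the older tf.nn.batch_norm_with_global_normalization calls.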