Commit c759f211 authored by Yuxin Wu's avatar Yuxin Wu

use official batchnorm op

parent df89a95f
...@@ -19,8 +19,7 @@ from tensorpack.dataflow import * ...@@ -19,8 +19,7 @@ from tensorpack.dataflow import *
from tensorpack.dataflow import imgaug from tensorpack.dataflow import imgaug
""" """
CIFAR10 90% validation accuracy after 100k step. CIFAR10 90% validation accuracy after 70k step.
91% after 160k step
""" """
BATCH_SIZE = 128 BATCH_SIZE = 128
...@@ -43,6 +42,7 @@ class Model(ModelDesc): ...@@ -43,6 +42,7 @@ class Model(ModelDesc):
num_threads=6, enqueue_many=True) num_threads=6, enqueue_many=True)
tf.image_summary("train_image", image, 10) tf.image_summary("train_image", image, 10)
image = image / 4.0 # just to make range smaller
l = Conv2D('conv1.1', image, out_channel=64, kernel_shape=3) l = Conv2D('conv1.1', image, out_channel=64, kernel_shape=3)
l = Conv2D('conv1.2', l, out_channel=64, kernel_shape=3, nl=tf.identity) l = Conv2D('conv1.2', l, out_channel=64, kernel_shape=3, nl=tf.identity)
l = BatchNorm('bn1', l, is_training) l = BatchNorm('bn1', l, is_training)
...@@ -112,7 +112,7 @@ def get_data(train_or_test): ...@@ -112,7 +112,7 @@ def get_data(train_or_test):
ds = AugmentImageComponent(ds, augmentors) ds = AugmentImageComponent(ds, augmentors)
ds = BatchData(ds, 128, remainder=not isTrain) ds = BatchData(ds, 128, remainder=not isTrain)
if isTrain: if isTrain:
ds = PrefetchData(ds, 3, 2) ds = PrefetchData(ds, 10, 5)
return ds return ds
...@@ -120,7 +120,7 @@ def get_data(train_or_test): ...@@ -120,7 +120,7 @@ def get_data(train_or_test):
def get_config(): def get_config():
# prepare dataset # prepare dataset
dataset_train = get_data('train') dataset_train = get_data('train')
step_per_epoch = dataset_train.size() / 2 step_per_epoch = dataset_train.size()
dataset_test = get_data('test') dataset_test = get_data('test')
sess_config = get_default_sess_config() sess_config = get_default_sess_config()
......
...@@ -31,10 +31,7 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5): ...@@ -31,10 +31,7 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
""" """
shape = x.get_shape().as_list() shape = x.get_shape().as_list()
if len(shape) == 2: assert len(shape) in [2, 4]
x = tf.reshape(x, [-1, 1, 1, shape[1]])
shape = x.get_shape().as_list()
assert len(shape) == 4
n_out = shape[-1] # channel n_out = shape[-1] # channel
beta = tf.get_variable('beta', [n_out]) beta = tf.get_variable('beta', [n_out])
...@@ -42,7 +39,10 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5): ...@@ -42,7 +39,10 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
'gamma', [n_out], 'gamma', [n_out],
initializer=tf.constant_initializer(1.0)) initializer=tf.constant_initializer(1.0))
batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments') if len(shape) == 2:
batch_mean, batch_var = tf.nn.moments(x, [0], name='moments', keep_dims=False)
else:
batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments', keep_dims=False)
ema = tf.train.ExponentialMovingAverage(decay=decay) ema = tf.train.ExponentialMovingAverage(decay=decay)
ema_apply_op = ema.apply([batch_mean, batch_var]) ema_apply_op = ema.apply([batch_mean, batch_var])
...@@ -50,10 +50,10 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5): ...@@ -50,10 +50,10 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
if use_local_stat: if use_local_stat:
with tf.control_dependencies([ema_apply_op]): with tf.control_dependencies([ema_apply_op]):
return tf.nn.batch_norm_with_global_normalization( return tf.nn.batch_normalization(
x, batch_mean, batch_var, beta, gamma, epsilon, True) x, batch_mean, batch_var, beta, gamma, epsilon, 'bn')
else: else:
batch = tf.cast(tf.shape(x)[0], tf.float32) batch = tf.cast(tf.shape(x)[0], tf.float32)
mean, var = ema_mean, ema_var * batch / (batch - 1) # unbiased variance estimator mean, var = ema_mean, ema_var * batch / (batch - 1) # unbiased variance estimator
return tf.nn.batch_norm_with_global_normalization( return tf.nn.batch_normalization(
x, mean, var, beta, gamma, epsilon, True) x, mean, var, beta, gamma, epsilon, 'bn')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment