Commit f1d15364 authored by Yuxin Wu

add BNV2 which uses fused_batch_norm

parent 78ccd295
@@ -34,4 +34,5 @@ It requires the datasets released by the original authors.
Reproduce an MNIST experiment in InfoGAN.
By assuming 10 latent variables corresponding to a categorical distribution and maximizing mutual information,
the network learns to map the 10 variables to 10 digits in a completely unsupervised way.
![infogan](demo/InfoGAN-mnist.jpg)
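For context only (not part of this commit): the mutual-information term for such a categorical code is commonly implemented as a cross-entropy between the sampled one-hot code and the auxiliary network's predicted posterior. A minimal sketch, where `q_logits` and `z_cat` are assumed placeholder names rather than identifiers from this repository:

```python
import tensorflow as tf

def categorical_mi_bonus(q_logits, z_cat):
    """Lower bound (up to the constant prior entropy) on I(c; G(z, c)).

    q_logits: predicted posterior Q(c|x) over the 10 codes, shape [batch, 10]
    z_cat:    one-hot categorical code fed to the generator, shape [batch, 10]
    """
    # Negative cross-entropy between Q(c|x) and the sampled code;
    # this term is added to the generator/Q objective and maximized.
    cross_ent = tf.nn.softmax_cross_entropy_with_logits(
        logits=q_logits, labels=z_cat)
    return -tf.reduce_mean(cross_ent)
```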
@@ -5,6 +5,7 @@
import tensorflow as tf
from tensorflow.contrib.framework import add_model_variable
from tensorflow.python.training import moving_averages
from copy import copy
import re
@@ -12,13 +13,12 @@ from ..tfutils.tower import get_current_tower_context
from ..utils import logger
from ._common import layer_register
-__all__ = ['BatchNorm']
+__all__ = ['BatchNorm', 'BatchNormV1', 'BatchNormV2']
# TF batch_norm only works for 4D tensor right now: #804
# decay: being too close to 1 leads to slow start-up. torch use 0.9.
# eps: torch: 1e-5. Lasagne: 1e-4
@layer_register(log_shape=False)
-def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
+def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
    """
    Batch normalization layer as described in:
@@ -107,3 +107,74 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
    else:
        return tf.nn.batch_normalization(
            x, ema_mean, ema_var, beta, gamma, epsilon, 'output')
@layer_register(log_shape=False)
def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
    """
    Batch normalization layer as described in:

    `Batch Normalization: Accelerating Deep Network Training by
    Reducing Internal Covariate Shift <http://arxiv.org/abs/1502.03167>`_.

    :param x: a NHWC or NC tensor
    :param use_local_stat: bool. whether to use mean/var of this batch or the moving average.
        Defaults to True in training and False in inference.
    :param decay: decay rate. default to 0.9.
    :param epsilon: default to 1e-5.
    """
    shape = x.get_shape().as_list()
    assert len(shape) in [2, 4]
    n_out = shape[-1]  # channel
    assert n_out is not None, "Input to BatchNorm cannot have unknown channels!"
    if len(shape) == 2:
        x = tf.reshape(x, [-1, 1, 1, n_out])

    beta = tf.get_variable('beta', [n_out],
                           initializer=tf.zeros_initializer)
    gamma = tf.get_variable('gamma', [n_out],
                            initializer=tf.constant_initializer(1.0))
    # x * gamma + beta

    ctx = get_current_tower_context()
    if use_local_stat is None:
        use_local_stat = ctx.is_training
    if use_local_stat != ctx.is_training:
        logger.warn("[BatchNorm] use_local_stat != is_training")

    moving_mean = tf.get_variable('mean/EMA', [n_out],
                                  initializer=tf.zeros_initializer, trainable=False)
    moving_var = tf.get_variable('variance/EMA', [n_out],
                                 initializer=tf.zeros_initializer, trainable=False)

    if use_local_stat:
        xn, batch_mean, batch_var = tf.nn.fused_batch_norm(
            x, gamma, beta, epsilon=epsilon, is_training=ctx.is_training)

        if ctx.is_training:
            # maintain EMA if training
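            # assign_moving_average updates the stored statistic towards the batch statistic:
            #   moving_stat = decay * moving_stat + (1 - decay) * batch_stat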
            update_op1 = moving_averages.assign_moving_average(
                moving_mean, batch_mean, decay, zero_debias=False,
                name='mean_ema_op')
            update_op2 = moving_averages.assign_moving_average(
                moving_var, batch_var, decay, zero_debias=False,
                name='var_ema_op')
            if ctx.is_main_training_tower:
                add_model_variable(moving_mean)
                add_model_variable(moving_var)
    else:
        assert not ctx.is_training, "In training, local statistics have to be used!"
        # TODO do I need to add_model_variable.
        # assume some fixed-param tasks, such as load model and fine tune one layer

        # fused is slower in inference
        # xn, _, _ = tf.nn.fused_batch_norm(x, gamma, beta,
        #     moving_mean, moving_var,
        #     epsilon=epsilon, is_training=False, name='output')
        xn = tf.nn.batch_normalization(
            x, moving_mean, moving_var, beta, gamma, epsilon)

    if ctx.is_training:
        with tf.control_dependencies([update_op1, update_op2]):
            return tf.identity(xn, name='output')
    else:
        return tf.identity(xn, name='output')

BatchNorm = BatchNormV2
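
A minimal usage sketch (not part of the diff; it assumes tensorpack's convention of passing the variable-scope name as the first argument, and an already-defined NHWC feature tensor `l`):

```python
l = BatchNorm('bn0', l)               # batch statistics in training, the EMAs at inference
l = BatchNorm('bn1', l, decay=0.999)  # optionally use a slower-moving average
```

Since `BatchNorm = BatchNormV2`, these calls go through `tf.nn.fused_batch_norm` during training and fall back to the unfused `tf.nn.batch_normalization` with the moving averages at inference.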