Commit 013565d6 authored by Yuxin Wu

Use tf.layers.BatchNormalization for implementation (#627)

parent edac0543
......@@ -24,9 +24,6 @@ To train Image-to-Image translation model with image pairs:
# you can download some data from the original authors:
# https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/
Speed:
On GTX1080 with BATCH=1, the speed is about 9.3it/s (the original torch version is 9.5it/s)
Training visualization will appear in tensorboard.
To visualize on test set:
./Image2Image.py --sample --data /path/to/test/datadir --mode {AtoB,BtoA} --load model
......@@ -71,7 +68,7 @@ class Model(GANModelDesc):
def generator(self, imgs):
# imgs: input: 256x256xch
# U-Net structure; it is slightly different from the original in the location of relu/lrelu
with argscope(BatchNorm, use_local_stat=True), \
with argscope(BatchNorm, training=True), \
argscope(Dropout, is_training=True):
# always use local stat for BN, and apply dropout even in testing
with argscope(Conv2D, kernel_size=4, strides=2, activation=BNLReLU):
......
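For context, `argscope` installs default keyword arguments for the listed layers, so the change above simply renames the default passed to every `BatchNorm` call inside the block from the old `use_local_stat=True` to the new `training=True`. A minimal sketch of the effect, assuming the standard tensorpack layer API (the `block` helper below is hypothetical):

from tensorpack import argscope
from tensorpack.models import Conv2D, BatchNorm, Dropout

def block(x, ch, name):
    # Every BatchNorm call in this scope receives training=True, i.e. it
    # always normalizes with the current batch's statistics, and Dropout
    # stays active even at inference time -- the behaviour pix2pix wants.
    with argscope(BatchNorm, training=True), \
            argscope(Dropout, is_training=True):
        x = Conv2D(name + '/conv', x, ch, kernel_size=4, strides=2)
        x = BatchNorm(name + '/bn', x)
        x = Dropout(name + '/drop', x)
    return x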
......@@ -20,7 +20,6 @@ Reproduce the following GAN-related methods, 100~200 lines each:
+ CycleGAN ([Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593))
Please see the __docstring__ in each script for detailed usage and pretrained models. MultiGPU training is supported.
## [DCGAN.py](DCGAN.py)
......
......@@ -56,11 +56,14 @@ def _inference_context():
class InferenceRunnerBase(Callback):
""" Base class for inference runner.
Please note that InferenceRunner will use `input.size()` to determine
Note:
1. InferenceRunner will use `input.size()` to determine
how many iterations to run, so you're responsible for ensuring that
`size()` is accurate.
`size()` is reasonable.
Also, InferenceRunner assumes that `trainer.model` exists.
2. Only works with instances of `TowerTrainer`.
"""
def __init__(self, input, infs):
"""
......
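As a concrete illustration of note 1, the dataflow handed to the runner must report a sensible `size()`; if it cannot, it can be wrapped so the runner knows how many batches to run. A hedged sketch using standard tensorpack utilities (`get_test_dataflow` is an assumed user-defined function):

from tensorpack import InferenceRunner, ScalarStats
from tensorpack.dataflow import FixedSizeData

ds_test = get_test_dataflow()        # assumed user-defined dataflow
# If ds_test cannot report its size, cap it explicitly so the runner
# knows how many iterations one inference pass should take.
ds_test = FixedSizeData(ds_test, 500)

callbacks = [
    # The runner calls ds_test.size() to decide how many batches to run.
    InferenceRunner(ds_test, [ScalarStats('cost')]),
]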
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: _old_batch_norm.py
import tensorflow as tf
from tensorflow.contrib.framework import add_model_variable
from tensorflow.python.training import moving_averages
from ..utils import logger
from ..utils.argtools import get_data_format
from ..tfutils.tower import get_current_tower_context
from ..tfutils.common import get_tf_version_number
from .common import layer_register, VariableHolder
from .tflayer import convert_to_tflayer_args
"""
Old Custom BN Implementation, Kept Here For Future Reference
"""
def get_bn_variables(n_out, use_scale, use_bias, gamma_init):
if use_bias:
beta = tf.get_variable('beta', [n_out], initializer=tf.constant_initializer())
else:
beta = tf.zeros([n_out], name='beta')
if use_scale:
gamma = tf.get_variable('gamma', [n_out], initializer=gamma_init)
else:
gamma = tf.ones([n_out], name='gamma')
# x * gamma + beta
moving_mean = tf.get_variable('mean/EMA', [n_out],
initializer=tf.constant_initializer(), trainable=False)
moving_var = tf.get_variable('variance/EMA', [n_out],
initializer=tf.constant_initializer(1.0), trainable=False)
return beta, gamma, moving_mean, moving_var
def update_bn_ema(xn, batch_mean, batch_var,
moving_mean, moving_var, decay, internal_update):
# TODO is there a way to use zero_debias in multi-GPU?
update_op1 = moving_averages.assign_moving_average(
moving_mean, batch_mean, decay, zero_debias=False,
name='mean_ema_op')
update_op2 = moving_averages.assign_moving_average(
moving_var, batch_var, decay, zero_debias=False,
name='var_ema_op')
if internal_update:
with tf.control_dependencies([update_op1, update_op2]):
return tf.identity(xn, name='output')
else:
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op1)
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op2)
return tf.identity(xn, name='output')
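# Usage sketch for the two update modes above (TF1 graph mode; optimizer and
# loss are placeholders): with internal_update=False the EMA ops land in
# tf.GraphKeys.UPDATE_OPS and the training loop must run them explicitly, e.g.
#
#     update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
#     with tf.control_dependencies(update_ops):
#         train_op = optimizer.minimize(loss)
#
# With internal_update=True the control dependency is created inside the
# layer, so simply running the layer's output also updates the EMAs.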
@layer_register()
@convert_to_tflayer_args(
args_names=[],
name_mapping={
'use_bias': 'center',
'use_scale': 'scale',
'gamma_init': 'gamma_initializer',
'decay': 'momentum',
'use_local_stat': 'training'
})
def BatchNorm(inputs, training=None, momentum=0.9, epsilon=1e-5,
center=True, scale=True,
gamma_initializer=tf.ones_initializer(),
data_format='channels_last',
internal_update=False):
"""
Mostly equivalent to `tf.layers.batch_normalization`, but different in
the following:
1. Accepts `data_format` rather than `axis`. For 2D input, this argument will be ignored.
2. Default value for `momentum` and `epsilon` is different.
3. Default value for `training` is automatically obtained from `TowerContext`.
4. Support the `internal_update` option.
Args:
internal_update (bool): if False, add EMA update ops to
`tf.GraphKeys.UPDATE_OPS`. If True, update EMA inside the layer
by control dependencies.
Variable Names:
* ``beta``: the bias term. Will be zero-inited by default.
* ``gamma``: the scale term. Will be one-inited by default. Input will be transformed by ``x * gamma + beta``.
* ``mean/EMA``: the moving average of mean.
* ``variance/EMA``: the moving average of variance.
Note:
1. About multi-GPU training: moving averages across GPUs are not aggregated.
Batch statistics are computed independently. This is consistent with most frameworks.
2. Combinations of ``training`` and ``ctx.is_training``:
* ``training == ctx.is_training``: standard BN, EMA are
maintained during training and used during inference. This is
the default.
* ``training and not ctx.is_training``: still use batch statistics in inference.
* ``not training and ctx.is_training``: use EMA to normalize in
training. This is useful when you load a pre-trained BN and
don't want to fine tune the EMA. EMA will not be updated in
this case.
"""
data_format = get_data_format(data_format, tfmode=False)
shape = inputs.get_shape().as_list()
ndims = len(shape)
assert ndims in [2, 4]
if ndims == 2:
data_format = 'NHWC'
if data_format == 'NCHW':
n_out = shape[1]
else:
n_out = shape[-1] # channel
assert n_out is not None, "Input to BatchNorm cannot have unknown channels!"
beta, gamma, moving_mean, moving_var = get_bn_variables(n_out, scale, center, gamma_initializer)
ctx = get_current_tower_context()
use_local_stat = training
if use_local_stat is None:
use_local_stat = ctx.is_training
use_local_stat = bool(use_local_stat)
if use_local_stat:
if ndims == 2:
inputs = tf.reshape(inputs, [-1, 1, 1, n_out]) # fused_bn only takes 4D input
# fused_bn has error using NCHW? (see #190)
xn, batch_mean, batch_var = tf.nn.fused_batch_norm(
inputs, gamma, beta, epsilon=epsilon,
is_training=True, data_format=data_format)
if ndims == 2:
xn = tf.squeeze(xn, [1, 2])
else:
if ctx.is_training:
assert get_tf_version_number() >= 1.4, \
"Fine tuning a BatchNorm model with fixed statistics is only " \
"supported after https://github.com/tensorflow/tensorflow/pull/12580 "
if ctx.is_main_training_tower: # only warn in first tower
logger.warn("[BatchNorm] Using moving_mean/moving_variance in training.")
# Using moving_mean/moving_variance in training, which means we
# loaded a pre-trained BN and only fine-tuning the affine part.
xn, _, _ = tf.nn.fused_batch_norm(
inputs, gamma, beta,
mean=moving_mean, variance=moving_var, epsilon=epsilon,
data_format=data_format, is_training=False)
else:
if ndims == 4:
xn, _, _ = tf.nn.fused_batch_norm(
inputs, gamma, beta,
mean=moving_mean, variance=moving_var, epsilon=epsilon,
data_format=data_format, is_training=False)
else:
# avoid the reshape if possible (when channel is the last dimension)
xn = tf.nn.batch_normalization(
inputs, moving_mean, moving_var, beta, gamma, epsilon)
# maintaining the EMA on only one GPU is OK, even in replicated mode,
# because training does not use the EMA
if ctx.is_main_training_tower:
add_model_variable(moving_mean)
add_model_variable(moving_var)
if ctx.is_main_training_tower and use_local_stat:
ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, momentum, internal_update)
else:
ret = tf.identity(xn, name='output')
vh = ret.variables = VariableHolder(mean=moving_mean, variance=moving_var)
if scale:
vh.gamma = gamma
if center:
vh.beta = beta
return ret
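The combinations of ``training`` and ``ctx.is_training`` listed in the docstring above are easiest to see in code. A hedged sketch of the "load a pre-trained BN but freeze its statistics" case, i.e. ``training=False`` while the tower is training (`my_backbone` is a placeholder for a user-defined model):

from tensorpack import argscope
from tensorpack.models import BatchNorm

def build_graph_with_frozen_bn(image):
    # Normalize with the stored mean/EMA and variance/EMA even though the
    # tower is in training mode; the EMAs themselves are never updated.
    with argscope(BatchNorm, training=False):
        return my_backbone(image)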
......@@ -5,7 +5,6 @@
import tensorflow as tf
from tensorflow.contrib.framework import add_model_variable
from tensorflow.python.training import moving_averages
from ..utils import logger
from ..utils.argtools import get_data_format
......@@ -13,7 +12,7 @@ from ..tfutils.tower import get_current_tower_context
from ..tfutils.common import get_tf_version_number
from ..tfutils.collection import backup_collection, restore_collection
from .common import layer_register, VariableHolder
from .tflayer import convert_to_tflayer_args
from .tflayer import convert_to_tflayer_args, rename_get_variable
__all__ = ['BatchNorm', 'BatchRenorm']
......@@ -21,51 +20,6 @@ __all__ = ['BatchNorm', 'BatchRenorm']
# eps: torch: 1e-5. Lasagne: 1e-4
def get_bn_variables(n_out, use_scale, use_bias, gamma_init):
if use_bias:
beta = tf.get_variable('beta', [n_out], initializer=tf.constant_initializer())
else:
beta = tf.zeros([n_out], name='beta')
if use_scale:
gamma = tf.get_variable('gamma', [n_out], initializer=gamma_init)
else:
gamma = tf.ones([n_out], name='gamma')
# x * gamma + beta
moving_mean = tf.get_variable('mean/EMA', [n_out],
initializer=tf.constant_initializer(), trainable=False)
moving_var = tf.get_variable('variance/EMA', [n_out],
initializer=tf.constant_initializer(1.0), trainable=False)
return beta, gamma, moving_mean, moving_var
def update_bn_ema(xn, batch_mean, batch_var,
moving_mean, moving_var, decay, internal_update):
# TODO is there a way to use zero_debias in multi-GPU?
update_op1 = moving_averages.assign_moving_average(
moving_mean, batch_mean, decay, zero_debias=False,
name='mean_ema_op')
update_op2 = moving_averages.assign_moving_average(
moving_var, batch_var, decay, zero_debias=False,
name='var_ema_op')
if internal_update:
with tf.control_dependencies([update_op1, update_op2]):
return tf.identity(xn, name='output')
else:
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op1)
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op2)
return tf.identity(xn, name='output')
def reshape_for_bn(param, ndims, chan, data_format):
if ndims == 2:
shape = [1, chan]
else:
shape = [1, 1, 1, chan] if data_format == 'NHWC' else [1, chan, 1, 1]
return tf.reshape(param, shape)
@layer_register()
@convert_to_tflayer_args(
args_names=[],
......@@ -82,7 +36,7 @@ def BatchNorm(inputs, training=None, momentum=0.9, epsilon=1e-5,
data_format='channels_last',
internal_update=False):
"""
Mostly equivalent to `tf.layers.batch_normalization`, but difference in
Mostly equivalent to `tf.layers.batch_normalization`, but different in
the following:
1. Accepts `data_format` rather than `axis`. For 2D input, this argument will be ignored.
......@@ -115,75 +69,68 @@ def BatchNorm(inputs, training=None, momentum=0.9, epsilon=1e-5,
don't want to fine tune the EMA. EMA will not be updated in
this case.
"""
# parse shapes
data_format = get_data_format(data_format, tfmode=False)
shape = inputs.get_shape().as_list()
ndims = len(shape)
assert ndims in [2, 4]
assert ndims in [2, 4], ndims
if ndims == 2:
data_format = 'NHWC'
if data_format == 'NCHW':
n_out = shape[1]
axis = 1
else:
n_out = shape[-1] # channel
assert n_out is not None, "Input to BatchNorm cannot have unknown channels!"
beta, gamma, moving_mean, moving_var = get_bn_variables(n_out, scale, center, gamma_initializer)
axis = 1 if data_format == 'NCHW' else 3
# parse training/ctx
ctx = get_current_tower_context()
use_local_stat = training
if use_local_stat is None:
use_local_stat = ctx.is_training
use_local_stat = bool(use_local_stat)
if use_local_stat:
if ndims == 2:
inputs = tf.reshape(inputs, [-1, 1, 1, n_out]) # fused_bn only takes 4D input
# fused_bn has error using NCHW? (see #190)
if training is None:
training = ctx.is_training
training = bool(training)
if not training and ctx.is_training:
assert get_tf_version_number() >= 1.4, \
"Fine tuning a BatchNorm model with fixed statistics is only " \
"supported after https://github.com/tensorflow/tensorflow/pull/12580 "
if ctx.is_main_training_tower: # only warn in first tower
logger.warn("[BatchNorm] Using moving_mean/moving_variance in training.")
# Using moving_mean/moving_variance in training, which means we
# loaded a pre-trained BN and only fine-tuning the affine part.
xn, batch_mean, batch_var = tf.nn.fused_batch_norm(
inputs, gamma, beta, epsilon=epsilon,
is_training=True, data_format=data_format)
if ndims == 2:
xn = tf.squeeze(xn, [1, 2])
else:
if ctx.is_training:
assert get_tf_version_number() >= 1.4, \
"Fine tuning a BatchNorm model with fixed statistics is only " \
"supported after https://github.com/tensorflow/tensorflow/pull/12580 "
if ctx.is_main_training_tower: # only warn in first tower
logger.warn("[BatchNorm] Using moving_mean/moving_variance in training.")
# Using moving_mean/moving_variance in training, which means we
# loaded a pre-trained BN and only fine-tuning the affine part.
xn, _, _ = tf.nn.fused_batch_norm(
inputs, gamma, beta,
mean=moving_mean, variance=moving_var, epsilon=epsilon,
data_format=data_format, is_training=False)
else:
if ndims == 4:
xn, _, _ = tf.nn.fused_batch_norm(
inputs, gamma, beta,
mean=moving_mean, variance=moving_var, epsilon=epsilon,
data_format=data_format, is_training=False)
else:
# avoid the reshape if possible (when channel is the last dimension)
xn = tf.nn.batch_normalization(
inputs, moving_mean, moving_var, beta, gamma, epsilon)
coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
with rename_get_variable(
{'moving_mean': 'mean/EMA',
'moving_variance': 'variance/EMA'}):
layer = tf.layers.BatchNormalization(
axis=axis,
momentum=momentum, epsilon=epsilon,
center=center, scale=scale,
gamma_initializer=gamma_initializer,
fused=True
)
xn = layer.apply(inputs, training=training, scope=tf.get_variable_scope())
# maintaining the EMA on only one GPU is OK, even in replicated mode,
# because training does not use the EMA
if ctx.is_main_training_tower:
add_model_variable(moving_mean)
add_model_variable(moving_var)
if ctx.is_main_training_tower and use_local_stat:
ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, momentum, internal_update)
for v in layer.non_trainable_variables:
add_model_variable(v)
if not ctx.is_main_training_tower or internal_update:
restore_collection(coll_bk)
if training and internal_update:
assert layer.updates
with tf.control_dependencies(layer.updates):
ret = tf.identity(xn, name='output')
else:
ret = tf.identity(xn, name='output')
vh = ret.variables = VariableHolder(mean=moving_mean, variance=moving_var)
vh = ret.variables = VariableHolder(
moving_mean=layer.moving_mean,
mean=layer.moving_mean, # for backward-compatibility
moving_variance=layer.moving_variance,
variance=layer.moving_variance) # for backward-compatibility
if scale:
vh.gamma = gamma
vh.gamma = layer.gamma
if center:
vh.beta = beta
vh.beta = layer.beta
return ret
......
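The new implementation delegates the math to `tf.layers.BatchNormalization`, while a renaming variable getter preserves tensorpack's historical checkpoint names (`mean/EMA`, `variance/EMA`) and the `UPDATE_OPS` backup/restore keeps control over when the EMA update ops are exposed. A condensed, hypothetical sketch of the renaming idea in plain TF 1.x (this is not the actual `rename_get_variable` implementation):

import tensorflow as tf

def bn_with_legacy_names(inputs, training, axis, momentum=0.9, epsilon=1e-5):
    # Map the names tf.layers uses internally onto the names tensorpack
    # checkpoints expect, via a custom variable getter.
    def rename_getter(getter, name, *args, **kwargs):
        name = name.replace('moving_mean', 'mean/EMA')
        name = name.replace('moving_variance', 'variance/EMA')
        return getter(name, *args, **kwargs)

    with tf.variable_scope('bn', custom_getter=rename_getter) as scope:
        layer = tf.layers.BatchNormalization(
            axis=axis, momentum=momentum, epsilon=epsilon, fused=True)
        xn = layer.apply(inputs, training=training, scope=scope)
    # layer.updates holds the EMA update ops; the caller decides whether they
    # go through tf.GraphKeys.UPDATE_OPS or through a control dependency.
    return tf.identity(xn, name='output'), layer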
......@@ -208,6 +208,8 @@ def add_moving_summary(*args, **kwargs):
collection (str or None): the name of the collection to add the EMA-maintaining ops to.
The default will work together with the default
:class:`MovingAverageSummary` callback.
summary_collections ([str]): the names of collections to add the
summary op. Default is TF's default (`tf.GraphKeys.SUMMARIES`).
Returns:
[tf.Tensor]: list of tensors returned by assign_moving_average,
......@@ -215,6 +217,7 @@ def add_moving_summary(*args, **kwargs):
"""
decay = kwargs.pop('decay', 0.95)
coll = kwargs.pop('collection', MOVING_SUMMARY_OPS_KEY)
summ_coll = kwargs.pop('summary_collections', None)
assert len(kwargs) == 0, "Unknown arguments: " + str(kwargs)
ctx = get_current_tower_context()
......@@ -248,7 +251,9 @@ def add_moving_summary(*args, **kwargs):
zero_debias=True, name=name + '_EMA_apply')
ema_ops.append(ema_op)
with tf.name_scope(None):
tf.summary.scalar(name + '-summary', ema_op) # write the EMA value as a summary
tf.summary.scalar(
name + '-summary', ema_op,
collections=summ_coll) # write the EMA value as a summary
if coll is not None:
for op in ema_ops:
tf.add_to_collection(coll, op)
......
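A short usage sketch of the new argument (the loss/metric tensors are placeholders for whatever the model defines):

from tensorpack.tfutils.summary import add_moving_summary

# Default behaviour: the EMA values are written as scalar summaries to
# tf.GraphKeys.SUMMARIES.
add_moving_summary(total_cost, accuracy)

# Keep maintaining the EMA, but send the summary op to a custom collection
# (an empty list would suppress the summary entirely).
add_moving_summary(aux_cost, summary_collections=['my_eval_summaries'])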