Commit 013565d6 authored by Yuxin Wu

Use tf.layers.BatchNormalization for implementation (#627)

parent edac0543
@@ -24,9 +24,6 @@ To train Image-to-Image translation model with image pairs:
 # you can download some data from the original authors:
 # https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/
-Speed:
-On GTX1080 with BATCH=1, the speed is about 9.3it/s (the original torch version is 9.5it/s)
 Training visualization will appear in tensorboard.
 To visualize on test set:
 ./Image2Image.py --sample --data /path/to/test/datadir --mode {AtoB,BtoA} --load model
@@ -71,7 +68,7 @@ class Model(GANModelDesc):
     def generator(self, imgs):
         # imgs: input: 256x256xch
         # U-Net structure, it's slightly different from the original on the location of relu/lrelu
-        with argscope(BatchNorm, use_local_stat=True), \
+        with argscope(BatchNorm, training=True), \
                 argscope(Dropout, is_training=True):
             # always use local stat for BN, and apply dropout even in testing
             with argscope(Conv2D, kernel_size=4, strides=2, activation=BNLReLU):
...
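The only change in this hunk is the keyword name: the layer now follows the tf.layers argument names, so `use_local_stat=True` becomes `training=True`. As a minimal sketch (not part of the commit; it assumes tensorpack's `argscope` and `BatchNorm` as used in this repo), this is the mechanism the hunk relies on: `argscope` injects the keyword into every `BatchNorm` call inside the block, which is how a single line pins the whole generator to per-batch statistics.

```python
from tensorpack import argscope
from tensorpack.models import BatchNorm

def two_bn_layers(x):
    # Every BatchNorm call inside this block receives training=True by
    # default, i.e. it always normalizes with per-batch statistics,
    # matching the old use_local_stat=True behaviour.
    with argscope(BatchNorm, training=True):
        x = BatchNorm('bn1', x)
        x = BatchNorm('bn2', x)
    return x
```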
@@ -20,7 +20,6 @@ Reproduce the following GAN-related methods, 100~200 lines each:
 + CycleGAN ([Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593))

 Please see the __docstring__ in each script for detailed usage and pretrained models. MultiGPU training is supported.

 ## [DCGAN.py](DCGAN.py)
...
@@ -56,11 +56,14 @@ def _inference_context():
 class InferenceRunnerBase(Callback):
     """ Base class for inference runner.
-    Please note that InferenceRunner will use `input.size()` to determine
+    Note:
+        1. InferenceRunner will use `input.size()` to determine
            how much iterations to run, so you're responsible to ensure that
-           `size()` is accurate.
-    Also, InferenceRunner assumes that `trainer.model` exists.
+           `size()` is reasonable.
+        2. Only works with instances of `TowerTrainer`.
     """

     def __init__(self, input, infs):
         """
...
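For context, here is a minimal usage sketch of the documented contract (not part of the commit; `FakeData`, `ScalarStats`, and the tensor name `'cost'` are illustrative): the dataflow handed to `InferenceRunner` must report a finite, reasonable `size()`, because that is how the callback decides how many iterations to run, and the surrounding trainer must be a `TowerTrainer`.

```python
from tensorpack import InferenceRunner, ScalarStats
from tensorpack.dataflow import FakeData

# A dataflow with a well-defined size(): 100 batches of fake images/labels.
dataset_test = FakeData([[64, 28, 28, 3], [64]], size=100)

# The runner executes dataset_test.size() == 100 iterations per epoch;
# 'cost' stands for a scalar tensor defined by the model being trained.
inference_cb = InferenceRunner(dataset_test, [ScalarStats('cost')])
```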
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: _old_batch_norm.py
import tensorflow as tf
from tensorflow.contrib.framework import add_model_variable
from tensorflow.python.training import moving_averages
from ..utils import logger
from ..utils.argtools import get_data_format
from ..tfutils.tower import get_current_tower_context
from ..tfutils.common import get_tf_version_number
from .common import layer_register, VariableHolder
from .tflayer import convert_to_tflayer_args
"""
Old Custom BN Implementation, Kept Here For Future Reference
"""

def get_bn_variables(n_out, use_scale, use_bias, gamma_init):
    if use_bias:
        beta = tf.get_variable('beta', [n_out], initializer=tf.constant_initializer())
    else:
        beta = tf.zeros([n_out], name='beta')
    if use_scale:
        gamma = tf.get_variable('gamma', [n_out], initializer=gamma_init)
    else:
        gamma = tf.ones([n_out], name='gamma')
    # x * gamma + beta
    moving_mean = tf.get_variable('mean/EMA', [n_out],
                                  initializer=tf.constant_initializer(), trainable=False)
    moving_var = tf.get_variable('variance/EMA', [n_out],
                                 initializer=tf.constant_initializer(1.0), trainable=False)
    return beta, gamma, moving_mean, moving_var

def update_bn_ema(xn, batch_mean, batch_var,
                  moving_mean, moving_var, decay, internal_update):
    # TODO is there a way to use zero_debias in multi-GPU?
    update_op1 = moving_averages.assign_moving_average(
        moving_mean, batch_mean, decay, zero_debias=False,
        name='mean_ema_op')
    update_op2 = moving_averages.assign_moving_average(
        moving_var, batch_var, decay, zero_debias=False,
        name='var_ema_op')
    if internal_update:
        with tf.control_dependencies([update_op1, update_op2]):
            return tf.identity(xn, name='output')
    else:
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op1)
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op2)
        return tf.identity(xn, name='output')

@layer_register()
@convert_to_tflayer_args(
    args_names=[],
    name_mapping={
        'use_bias': 'center',
        'use_scale': 'scale',
        'gamma_init': 'gamma_initializer',
        'decay': 'momentum',
        'use_local_stat': 'training'
    })
def BatchNorm(inputs, training=None, momentum=0.9, epsilon=1e-5,
              center=True, scale=True,
              gamma_initializer=tf.ones_initializer(),
              data_format='channels_last',
              internal_update=False):
"""
Mostly equivalent to `tf.layers.batch_normalization`, but difference in
the following:
1. Accepts `data_format` rather than `axis`. For 2D input, this argument will be ignored.
2. Default value for `momentum` and `epsilon` is different.
3. Default value for `training` is automatically obtained from `TowerContext`.
4. Support the `internal_update` option.
Args:
internal_update (bool): if False, add EMA update ops to
`tf.GraphKeys.UPDATE_OPS`. If True, update EMA inside the layer
by control dependencies.
Variable Names:
* ``beta``: the bias term. Will be zero-inited by default.
* ``gamma``: the scale term. Will be one-inited by default. Input will be transformed by ``x * gamma + beta``.
* ``mean/EMA``: the moving average of mean.
* ``variance/EMA``: the moving average of variance.
Note:
1. About multi-GPU training: moving averages across GPUs are not aggregated.
Batch statistics are computed independently. This is consistent with most frameworks.
2. Combinations of ``training`` and ``ctx.is_training``:
* ``training == ctx.is_training``: standard BN, EMA are
maintained during training and used during inference. This is
the default.
* ``training and not ctx.is_training``: still use batch statistics in inference.
* ``not training and ctx.is_training``: use EMA to normalize in
training. This is useful when you load a pre-trained BN and
don't want to fine tune the EMA. EMA will not be updated in
this case.
"""
    data_format = get_data_format(data_format, tfmode=False)
    shape = inputs.get_shape().as_list()
    ndims = len(shape)
    assert ndims in [2, 4]
    if ndims == 2:
        data_format = 'NHWC'
    if data_format == 'NCHW':
        n_out = shape[1]
    else:
        n_out = shape[-1]  # channel
    assert n_out is not None, "Input to BatchNorm cannot have unknown channels!"
    beta, gamma, moving_mean, moving_var = get_bn_variables(n_out, scale, center, gamma_initializer)

    ctx = get_current_tower_context()
    use_local_stat = training
    if use_local_stat is None:
        use_local_stat = ctx.is_training
    use_local_stat = bool(use_local_stat)

    if use_local_stat:
        if ndims == 2:
            inputs = tf.reshape(inputs, [-1, 1, 1, n_out])    # fused_bn only takes 4D input
            # fused_bn has error using NCHW? (see #190)
        xn, batch_mean, batch_var = tf.nn.fused_batch_norm(
            inputs, gamma, beta, epsilon=epsilon,
            is_training=True, data_format=data_format)
        if ndims == 2:
            xn = tf.squeeze(xn, [1, 2])
    else:
        if ctx.is_training:
            assert get_tf_version_number() >= 1.4, \
                "Fine tuning a BatchNorm model with fixed statistics is only " \
                "supported after https://github.com/tensorflow/tensorflow/pull/12580 "
            if ctx.is_main_training_tower:  # only warn in first tower
                logger.warn("[BatchNorm] Using moving_mean/moving_variance in training.")
            # Using moving_mean/moving_variance in training, which means we
            # loaded a pre-trained BN and only fine-tuning the affine part.
            xn, _, _ = tf.nn.fused_batch_norm(
                inputs, gamma, beta,
                mean=moving_mean, variance=moving_var, epsilon=epsilon,
                data_format=data_format, is_training=False)
        else:
            if ndims == 4:
                xn, _, _ = tf.nn.fused_batch_norm(
                    inputs, gamma, beta,
                    mean=moving_mean, variance=moving_var, epsilon=epsilon,
                    data_format=data_format, is_training=False)
            else:
                # avoid the reshape if possible (when channel is the last dimension)
                xn = tf.nn.batch_normalization(
                    inputs, moving_mean, moving_var, beta, gamma, epsilon)

    # maintain EMA only on one GPU is OK, even in replicated mode.
    # because training time doesn't use EMA
    if ctx.is_main_training_tower:
        add_model_variable(moving_mean)
        add_model_variable(moving_var)
    if ctx.is_main_training_tower and use_local_stat:
        ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, momentum, internal_update)
    else:
        ret = tf.identity(xn, name='output')

    vh = ret.variables = VariableHolder(mean=moving_mean, variance=moving_var)
    if scale:
        vh.gamma = gamma
    if center:
        vh.beta = beta
    return ret
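To make the docstring's third combination concrete, here is a small illustrative sketch (not part of the original file; it assumes TF 1.x, tensorpack's `TowerContext`, and the `BatchNorm` layer defined above): calling the layer with `training=False` inside a training tower normalizes with the loaded EMA statistics and never updates them, so only the affine `beta`/`gamma` get fine-tuned.

```python
import tensorflow as tf
from tensorpack.tfutils.tower import TowerContext

# Assumes the BatchNorm defined above is importable in this scope.
with TowerContext('', is_training=True):
    image = tf.placeholder(tf.float32, [None, 224, 224, 3], name='image')
    # "not training and ctx.is_training": EMA statistics are used for
    # normalization and are not updated; a warning is logged once.
    normed = BatchNorm('frozen_bn', image, training=False)
```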
@@ -5,7 +5,6 @@
 import tensorflow as tf
 from tensorflow.contrib.framework import add_model_variable
-from tensorflow.python.training import moving_averages

 from ..utils import logger
 from ..utils.argtools import get_data_format
@@ -13,7 +12,7 @@ from ..tfutils.tower import get_current_tower_context
 from ..tfutils.common import get_tf_version_number
 from ..tfutils.collection import backup_collection, restore_collection
 from .common import layer_register, VariableHolder
-from .tflayer import convert_to_tflayer_args
+from .tflayer import convert_to_tflayer_args, rename_get_variable

 __all__ = ['BatchNorm', 'BatchRenorm']
@@ -21,51 +20,6 @@ __all__ = ['BatchNorm', 'BatchRenorm']
 # eps: torch: 1e-5. Lasagne: 1e-4

-def get_bn_variables(n_out, use_scale, use_bias, gamma_init):
-    if use_bias:
-        beta = tf.get_variable('beta', [n_out], initializer=tf.constant_initializer())
-    else:
-        beta = tf.zeros([n_out], name='beta')
-    if use_scale:
-        gamma = tf.get_variable('gamma', [n_out], initializer=gamma_init)
-    else:
-        gamma = tf.ones([n_out], name='gamma')
-    # x * gamma + beta
-    moving_mean = tf.get_variable('mean/EMA', [n_out],
-                                  initializer=tf.constant_initializer(), trainable=False)
-    moving_var = tf.get_variable('variance/EMA', [n_out],
-                                 initializer=tf.constant_initializer(1.0), trainable=False)
-    return beta, gamma, moving_mean, moving_var
-
-
-def update_bn_ema(xn, batch_mean, batch_var,
-                  moving_mean, moving_var, decay, internal_update):
-    # TODO is there a way to use zero_debias in multi-GPU?
-    update_op1 = moving_averages.assign_moving_average(
-        moving_mean, batch_mean, decay, zero_debias=False,
-        name='mean_ema_op')
-    update_op2 = moving_averages.assign_moving_average(
-        moving_var, batch_var, decay, zero_debias=False,
-        name='var_ema_op')
-    if internal_update:
-        with tf.control_dependencies([update_op1, update_op2]):
-            return tf.identity(xn, name='output')
-    else:
-        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op1)
-        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op2)
-        return tf.identity(xn, name='output')
-
-
-def reshape_for_bn(param, ndims, chan, data_format):
-    if ndims == 2:
-        shape = [1, chan]
-    else:
-        shape = [1, 1, 1, chan] if data_format == 'NHWC' else [1, chan, 1, 1]
-    return tf.reshape(param, shape)
-
-
 @layer_register()
 @convert_to_tflayer_args(
     args_names=[],
@@ -82,7 +36,7 @@ def BatchNorm(inputs, training=None, momentum=0.9, epsilon=1e-5,
               data_format='channels_last',
               internal_update=False):
     """
-    Mostly equivalent to `tf.layers.batch_normalization`, but difference in
+    Mostly equivalent to `tf.layers.batch_normalization`, but different in
     the following:

     1. Accepts `data_format` rather than `axis`. For 2D input, this argument will be ignored.
@@ -115,38 +69,23 @@ def BatchNorm(inputs, training=None, momentum=0.9, epsilon=1e-5,
             don't want to fine tune the EMA. EMA will not be updated in
             this case.
     """
+    # parse shapes
     data_format = get_data_format(data_format, tfmode=False)
     shape = inputs.get_shape().as_list()
     ndims = len(shape)
-    assert ndims in [2, 4]
+    assert ndims in [2, 4], ndims
     if ndims == 2:
         data_format = 'NHWC'
-    if data_format == 'NCHW':
-        n_out = shape[1]
-    else:
-        n_out = shape[-1]  # channel
-    assert n_out is not None, "Input to BatchNorm cannot have unknown channels!"
-    beta, gamma, moving_mean, moving_var = get_bn_variables(n_out, scale, center, gamma_initializer)
+        axis = 1
+    else:
+        axis = 1 if data_format == 'NCHW' else 3

+    # parse training/ctx
     ctx = get_current_tower_context()
-    use_local_stat = training
-    if use_local_stat is None:
-        use_local_stat = ctx.is_training
-    use_local_stat = bool(use_local_stat)
-
-    if use_local_stat:
-        if ndims == 2:
-            inputs = tf.reshape(inputs, [-1, 1, 1, n_out])    # fused_bn only takes 4D input
-            # fused_bn has error using NCHW? (see #190)
-        xn, batch_mean, batch_var = tf.nn.fused_batch_norm(
-            inputs, gamma, beta, epsilon=epsilon,
-            is_training=True, data_format=data_format)
-        if ndims == 2:
-            xn = tf.squeeze(xn, [1, 2])
-    else:
-        if ctx.is_training:
+    if training is None:
+        training = ctx.is_training
+    training = bool(training)
+    if not training and ctx.is_training:
         assert get_tf_version_number() >= 1.4, \
             "Fine tuning a BatchNorm model with fixed statistics is only " \
             "supported after https://github.com/tensorflow/tensorflow/pull/12580 "
@@ -154,36 +93,44 @@ def BatchNorm(inputs, training=None, momentum=0.9, epsilon=1e-5,
            logger.warn("[BatchNorm] Using moving_mean/moving_variance in training.")
        # Using moving_mean/moving_variance in training, which means we
        # loaded a pre-trained BN and only fine-tuning the affine part.
-        xn, _, _ = tf.nn.fused_batch_norm(
-            inputs, gamma, beta,
-            mean=moving_mean, variance=moving_var, epsilon=epsilon,
-            data_format=data_format, is_training=False)
-    else:
-        if ndims == 4:
-            xn, _, _ = tf.nn.fused_batch_norm(
-                inputs, gamma, beta,
-                mean=moving_mean, variance=moving_var, epsilon=epsilon,
-                data_format=data_format, is_training=False)
-        else:
-            # avoid the reshape if possible (when channel is the last dimension)
-            xn = tf.nn.batch_normalization(
-                inputs, moving_mean, moving_var, beta, gamma, epsilon)
+
+    coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
+    with rename_get_variable(
+            {'moving_mean': 'mean/EMA',
+             'moving_variance': 'variance/EMA'}):
+        layer = tf.layers.BatchNormalization(
+            axis=axis,
+            momentum=momentum, epsilon=epsilon,
+            center=center, scale=scale,
+            gamma_initializer=gamma_initializer,
+            fused=True
+        )
+        xn = layer.apply(inputs, training=training, scope=tf.get_variable_scope())

     # maintain EMA only on one GPU is OK, even in replicated mode.
     # because training time doesn't use EMA
     if ctx.is_main_training_tower:
-        add_model_variable(moving_mean)
-        add_model_variable(moving_var)
+        for v in layer.non_trainable_variables:
+            add_model_variable(v)
+    if not ctx.is_main_training_tower or internal_update:
+        restore_collection(coll_bk)

-    if ctx.is_main_training_tower and use_local_stat:
-        ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, momentum, internal_update)
+    if training and internal_update:
+        assert layer.updates
+        with tf.control_dependencies(layer.updates):
+            ret = tf.identity(xn, name='output')
     else:
         ret = tf.identity(xn, name='output')

-    vh = ret.variables = VariableHolder(mean=moving_mean, variance=moving_var)
+    vh = ret.variables = VariableHolder(
+        moving_mean=layer.moving_mean,
+        mean=layer.moving_mean,  # for backward-compatibility
+        moving_variance=layer.moving_variance,
+        variance=layer.moving_variance)  # for backward-compatibility
     if scale:
-        vh.gamma = gamma
+        vh.gamma = layer.gamma
     if center:
-        vh.beta = beta
+        vh.beta = layer.beta
     return ret
...
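The essence of the rewrite is to let `tf.layers.BatchNormalization` own the variables and the EMA update ops instead of building them by hand, and to decide whether those updates end up in `tf.GraphKeys.UPDATE_OPS` or in control dependencies on the output. Below is a stripped-down sketch of that pattern in plain TensorFlow 1.x (not the tensorpack layer itself; it omits the tower logic and the `rename_get_variable` trick that preserves the old `mean/EMA` / `variance/EMA` names):

```python
import tensorflow as tf

def bn_with_internal_update(x, training, momentum=0.9, epsilon=1e-5):
    # Let the built-in layer create beta/gamma and the moving averages.
    layer = tf.layers.BatchNormalization(
        axis=-1, momentum=momentum, epsilon=epsilon, fused=True)
    out = layer.apply(x, training=training)
    if training:
        # layer.updates holds the EMA assign ops created by apply();
        # attaching them as control dependencies removes the need to run
        # tf.GraphKeys.UPDATE_OPS separately in the training loop.
        with tf.control_dependencies(layer.updates):
            out = tf.identity(out, name='output')
    return out

x = tf.placeholder(tf.float32, [None, 32, 32, 16])
y = bn_with_internal_update(x, training=True)
```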
@@ -208,6 +208,8 @@ def add_moving_summary(*args, **kwargs):
         collection (str or None): the name of the collection to add EMA-maintaining ops.
             The default will work together with the default
             :class:`MovingAverageSummary` callback.
+        summary_collections ([str]): the names of collections to add the
+            summary op. Default is TF's default (`tf.GraphKeys.SUMMARIES`).

     Returns:
         [tf.Tensor]: list of tensors returned by assign_moving_average,
@@ -215,6 +217,7 @@ def add_moving_summary(*args, **kwargs):
     """
     decay = kwargs.pop('decay', 0.95)
     coll = kwargs.pop('collection', MOVING_SUMMARY_OPS_KEY)
+    summ_coll = kwargs.pop('summary_collections', None)
     assert len(kwargs) == 0, "Unknown arguments: " + str(kwargs)

     ctx = get_current_tower_context()
@@ -248,7 +251,9 @@ def add_moving_summary(*args, **kwargs):
             zero_debias=True, name=name + '_EMA_apply')
         ema_ops.append(ema_op)
         with tf.name_scope(None):
-            tf.summary.scalar(name + '-summary', ema_op)  # write the EMA value as a summary
+            tf.summary.scalar(
+                name + '-summary', ema_op,
+                collections=summ_coll)  # write the EMA value as a summary
     if coll is not None:
         for op in ema_ops:
             tf.add_to_collection(coll, op)
...
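A short usage sketch of the new keyword (illustrative, not from the commit; the collection name and the `'cost'` tensor are made up): by default the EMA scalar summary still lands in `tf.GraphKeys.SUMMARIES`, while passing `summary_collections` routes it to the named collections instead.

```python
import tensorflow as tf
from tensorpack.tfutils.summary import add_moving_summary

# 'cost' stands for any scalar metric the model already produces.
cost = tf.reduce_mean(tf.random_normal([32]), name='cost')

# Route the EMA summary to a custom collection only; the EMA-maintaining op
# still goes to the default moving-summary collection.
add_moving_summary(cost, summary_collections=['monitoring_only'])
```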