Commit e0a7e8f9 authored by Yuxin Wu

Better support for virtual_batch_size

parent 9f4600a0
......@@ -67,7 +67,7 @@ Dependencies:
+ Python 3.3+.
+ Python bindings for OpenCV. (Optional, but required by a lot of features)
+ TensorFlow ≥ 1.3, < 2. (Not required if you only want to use `tensorpack.dataflow` alone as a data processing library)
+ TensorFlow ≥ 1.5, < 2. (Not required if you only want to use `tensorpack.dataflow` alone as a data processing library)
```
pip install --upgrade git+https://github.com/tensorpack/tensorpack.git
# or add `--user` to install to user's local directories
......
......@@ -11,7 +11,6 @@ from ..tfutils.common import get_tf_version_tuple
from ..tfutils.tower import get_current_tower_context
from ..utils import logger
from ..utils.argtools import get_data_format, log_once
from ..utils.develop import log_deprecated
from .common import VariableHolder, layer_register
from .tflayer import convert_to_tflayer_args, rename_get_variable
from .utils import disable_autograph
......@@ -60,6 +59,59 @@ def internal_update_bn_ema(xn, batch_mean, batch_var,
return tf.identity(xn, name='output')
def get_sync_bn_mean_var(inputs, red_axis, sync_statistics):
    """
    Compute the per-tower E[x] and E[x^2], all-reduce them across
    towers/ranks, and return the resulting global (mean, variance).
    """
    ctx = get_current_tower_context()
    batch_mean = tf.reduce_mean(inputs, axis=red_axis)
    batch_mean_square = tf.reduce_mean(tf.square(inputs), axis=red_axis)

    TF_version = get_tf_version_tuple()
    if sync_statistics == 'nccl':
        num_dev = ctx.total
        if num_dev == 1:
            logger.warn("BatchNorm(sync_statistics='nccl') is used with only one tower!")
        else:
            assert TF_version >= (1, 10), \
                "Cross-GPU BatchNorm is only supported in TF>=1.10. " \
                "Upgrade TF or apply this patch manually: https://github.com/tensorflow/tensorflow/pull/20360"

            if TF_version <= (1, 12):
                try:
                    from tensorflow.contrib.nccl.python.ops.nccl_ops import _validate_and_load_nccl_so  # deprecated
                except Exception:
                    pass
                else:
                    _validate_and_load_nccl_so()
                from tensorflow.contrib.nccl.ops import gen_nccl_ops  # deprecated
            else:
                from tensorflow.python.ops import gen_nccl_ops
            shared_name = re.sub('tower[0-9]+/', '', tf.get_variable_scope().name)
            batch_mean = gen_nccl_ops.nccl_all_reduce(
                input=batch_mean,
                reduction='sum',
                num_devices=num_dev,
                shared_name=shared_name + '_NCCL_mean') * (1.0 / num_dev)
            batch_mean_square = gen_nccl_ops.nccl_all_reduce(
                input=batch_mean_square,
                reduction='sum',
                num_devices=num_dev,
                shared_name=shared_name + '_NCCL_mean_square') * (1.0 / num_dev)
    elif sync_statistics == 'horovod':
        # Require https://github.com/uber/horovod/pull/331
        import horovod.tensorflow as hvd
        if hvd.size() == 1:
            logger.warn("BatchNorm(sync_statistics='horovod') is used with only one process!")
        else:
            import horovod
            hvd_version = tuple(map(int, horovod.__version__.split('.')[:3]))
            assert hvd_version >= (0, 13, 6), "sync_statistics=horovod needs horovod>=0.13.6!"

            batch_mean = hvd.allreduce(batch_mean, average=True)
            batch_mean_square = hvd.allreduce(batch_mean_square, average=True)
    batch_var = batch_mean_square - tf.square(batch_mean)
    return batch_mean, batch_var
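As a sanity check of the identity this helper relies on (not part of the commit): averaging each worker's E[x] and E[x^2] and then forming Var = E[x^2] - E[x]^2 reproduces the statistics of the combined batch, assuming equal per-worker batch sizes. A standalone NumPy sketch:

```
import numpy as np

np.random.seed(0)
per_gpu = [np.random.randn(32, 64) for _ in range(4)]   # 4 towers, local batch 32, 64 channels

means = [x.mean(axis=0) for x in per_gpu]                # per-tower E[x]
mean_sqs = [(x ** 2).mean(axis=0) for x in per_gpu]      # per-tower E[x^2]

# What the all-reduce (sum, then divide by num_dev) computes:
global_mean = np.mean(means, axis=0)
global_var = np.mean(mean_sqs, axis=0) - global_mean ** 2

# Same result as taking moments over the concatenated global batch:
full = np.concatenate(per_gpu, axis=0)
assert np.allclose(global_mean, full.mean(axis=0))
assert np.allclose(global_var, full.var(axis=0))
```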
@layer_register()
@convert_to_tflayer_args(
args_names=[],
......@@ -78,8 +130,7 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
virtual_batch_size=None,
data_format='channels_last',
ema_update='default',
sync_statistics=None,
internal_update=None):
sync_statistics=None):
"""
A more powerful version of `tf.layers.batch_normalization`. It differs from
the official one in the following aspects:
......@@ -90,11 +141,19 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
User-provided value can overwrite this behavior.
4. Support the ``ema_update`` option, which covers broader use cases than the standard EMA update.
5. Support the ``sync_statistics`` option, which implements "SyncBN" and is very useful in small-batch models.
6. Better support of the ``virtual_batch_size`` option that does not have the bugs in ``tf.layers``.
Args:
training (bool): if True, use per-batch statistics to normalize. Otherwise, use stored EMA
to normalize. By default, it is equal to `get_current_tower_context().is_training`.
This is not a good argument name, but it is what the TensorFlow layer uses.
virtual_batch_size (int): implement "Ghost BatchNorm" that normalizes
the data with a smaller batch size than the input. Only effective when training is True.
The value has to be a divisor of the actual batch size.
This layer does not use the buggy TensorFlow implementation, which has the following
problems: (1) wrong behavior at inference; (2) it creates variables with unnecessary size-1 dimensions.
Corresponding TF issue: https://github.com/tensorflow/tensorflow/issues/23050
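To make the documented behavior concrete, here is an illustrative NumPy sketch (not part of this commit; the helper name is purely illustrative) of what "Ghost BatchNorm" computes: the batch is split into groups of `virtual_batch_size` samples and each group is normalized with its own mean and variance, which is why the value must divide the actual batch size.

```
import numpy as np

def ghost_bn_reference(x, virtual_batch_size, eps=1e-5):
    # x: [B, C]; B must be a multiple of virtual_batch_size.
    B = x.shape[0]
    assert B % virtual_batch_size == 0
    out = np.empty_like(x)
    for start in range(0, B, virtual_batch_size):
        group = x[start:start + virtual_batch_size]
        mean = group.mean(axis=0)          # statistics of this virtual batch only
        var = group.var(axis=0)
        out[start:start + virtual_batch_size] = (group - mean) / np.sqrt(var + eps)
    return out

x = np.random.randn(32, 8).astype(np.float32)
y = ghost_bn_reference(x, virtual_batch_size=8)   # 4 groups of 8, each normalized independently
```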
ema_update (str): Only effective when ``training=True``. It has the following options:
* "default": same as "collection". Because this is the default behavior in TensorFlow.
......@@ -128,7 +187,7 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
* "horovod": this layer must be used under tensorpack's :class:`HorovodTrainer`.
It uses the aggregated statistics of the whole batch (across all MPI ranks) to normalize.
Note that on single machine this is significantly slower than the "nccl" implementation.
Note that on a single machine this is found to be slower than the "nccl" implementation.
When not None, each GPU computes its own E[x] and E[x^2],
which are then averaged among all GPUs to compute global mean & variance.
......@@ -151,8 +210,6 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
When `sync_statistics` is enabled, `ema_update` is set to "internal" automatically.
This is to avoid running `UPDATE_OPS`, which requires synchronization.
internal_update: deprecated option. Don't use.
Variable Names:
* ``beta``: the bias term. Will be zero-inited by default.
......@@ -175,6 +232,8 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
if training is None:
training = ctx.is_training
training = bool(training)
if not training:
virtual_batch_size = None
# parse shapes
data_format = get_data_format(data_format, keras_mode=False)
......@@ -186,11 +245,6 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
assert sync_statistics in [None, 'nccl', 'horovod'], sync_statistics
assert ema_update in ["default", "collection", "internal", "skip"]
if internal_update is not None:
log_deprecated("BatchNorm(internal_update=)", "Use ema_update='internal' instead!", "2020-01-01")
assert ema_update == 'default', \
"Do not use internal_update and ema_update together! internal_update is deprecated"
ema_update = "internal" if internal_update else "collection"
if ema_update == "default":
ema_update = "collection"
# Logic:
......@@ -211,12 +265,8 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
assert axis in [1, 3], axis
num_chan = shape[axis]
TF_version = get_tf_version_tuple()
freeze_bn_backward = not training and ctx.is_training
if freeze_bn_backward:
assert TF_version >= (1, 4), \
"Fine tuning a BatchNorm model with fixed statistics needs TF>=1.4!"
if ctx.is_main_training_tower: # only warn in first tower
log_once("Some BatchNorm layer uses moving_mean/moving_variance in training.", func='warn')
# Using moving_mean/moving_variance in training, which means we
......@@ -224,8 +274,9 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
do_sync_bn = (sync_statistics is not None) and training
if not do_sync_bn:
# Use the builtin layer for anything except for sync-bn
if not do_sync_bn and not virtual_batch_size:
# Use the builtin layer for regular per-GPU BN.
# Use our own implementation for SyncBN and GhostBN
coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
with rename_get_variable(
{'moving_mean': 'mean/EMA',
......@@ -239,10 +290,6 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
# https://github.com/tensorflow/tensorflow/issues/10857#issuecomment-410185429
fused=(ndims == 4 and axis in [1, 3] and not freeze_bn_backward),
_reuse=tf.get_variable_scope().reuse)
if TF_version >= (1, 5):
tf_args['virtual_batch_size'] = virtual_batch_size
else:
assert virtual_batch_size is None, "Feature not supported in this version of TF!"
use_fp16 = inputs.dtype == tf.float16
if use_fp16:
# non-fused does not support fp16; fused does not support all layouts.
......@@ -279,65 +326,39 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
vh.beta = layer.beta
else:
red_axis = [0] if ndims == 2 else ([0, 2, 3] if axis == 1 else [0, 1, 2])
beta, gamma, moving_mean, moving_var = get_bn_variables(
num_chan, scale, center, beta_initializer, gamma_initializer)
assert sync_statistics is None or virtual_batch_size is None, "Cannot use SyncBN and GhostBN together!"
new_shape = None # don't need to reshape unless ...
if ndims == 4 and axis == 1:
new_shape = [1, num_chan, 1, 1]
batch_mean = tf.reduce_mean(inputs, axis=red_axis)
batch_mean_square = tf.reduce_mean(tf.square(inputs), axis=red_axis)
if sync_statistics is not None:
# sync bn
batch_mean, batch_var = get_sync_bn_mean_var(inputs, red_axis, sync_statistics)
batch_mean_vec = batch_mean
batch_var_vec = batch_var
if sync_statistics == 'nccl':
num_dev = ctx.total
if num_dev == 1:
logger.warn("BatchNorm(sync_statistics='nccl') is used with only one tower!")
else:
assert TF_version >= (1, 10), \
"Cross-GPU BatchNorm is only supported in TF>=1.10 ." \
"Upgrade TF or apply this patch manually: https://github.com/tensorflow/tensorflow/pull/20360"
if TF_version <= (1, 12):
try:
from tensorflow.contrib.nccl.python.ops.nccl_ops import _validate_and_load_nccl_so # deprecated
except Exception:
pass
else:
_validate_and_load_nccl_so()
from tensorflow.contrib.nccl.ops import gen_nccl_ops # deprecated
else:
from tensorflow.python.ops import gen_nccl_ops
shared_name = re.sub('tower[0-9]+/', '', tf.get_variable_scope().name)
batch_mean = gen_nccl_ops.nccl_all_reduce(
input=batch_mean,
reduction='sum',
num_devices=num_dev,
shared_name=shared_name + '_NCCL_mean') * (1.0 / num_dev)
batch_mean_square = gen_nccl_ops.nccl_all_reduce(
input=batch_mean_square,
reduction='sum',
num_devices=num_dev,
shared_name=shared_name + '_NCCL_mean_square') * (1.0 / num_dev)
elif sync_statistics == 'horovod':
# Require https://github.com/uber/horovod/pull/331
import horovod.tensorflow as hvd
if hvd.size() == 1:
logger.warn("BatchNorm(sync_statistics='horovod') is used with only one process!")
else:
import horovod
hvd_version = tuple(map(int, horovod.__version__.split('.')[:3]))
assert hvd_version >= (0, 13, 6), "sync_statistics=horovod needs horovod>=0.13.6 !"
batch_mean = hvd.allreduce(batch_mean, average=True)
batch_mean_square = hvd.allreduce(batch_mean_square, average=True)
batch_var = batch_mean_square - tf.square(batch_mean)
batch_mean_vec = batch_mean
batch_var_vec = batch_var
if ndims == 4 and axis == 1:
new_shape = [1, num_chan, 1, 1]
batch_mean = tf.reshape(batch_mean, new_shape)
batch_var = tf.reshape(batch_var, new_shape)
else:
orig_shape = tf.shape(inputs)
inputs = tf.reshape(
inputs,
tf.concat([[-1, virtual_batch_size],
tf.shape(inputs)[1:]], axis=0))
# B/V, V, ...
red_axis = [x + 1 for x in red_axis]
new_shape = [1] * (ndims + 1)
new_shape[axis + 1] = num_chan
batch_mean, batch_var = tf.nn.moments(inputs, red_axis, keepdims=True)
# B/V, C
# vec for EMA update: use the first one only to mimic per-GPU BN
batch_mean_vec = tf.reshape(batch_mean[0], [num_chan])
batch_var_vec = tf.reshape(batch_var[0], [num_chan])
beta, gamma, moving_mean, moving_var = get_bn_variables(
num_chan, scale, center, beta_initializer, gamma_initializer)
if new_shape is not None:
batch_mean = tf.reshape(batch_mean, new_shape)
batch_var = tf.reshape(batch_var, new_shape)
# Using fused_batch_norm(is_training=False) is actually slightly faster,
# but hopefully this call will be JITed in the future.
xn = tf.nn.batch_normalization(
......@@ -348,6 +369,8 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
xn = tf.nn.batch_normalization(
inputs, batch_mean, batch_var,
beta, gamma, epsilon)
if virtual_batch_size is not None:
xn = tf.reshape(xn, orig_shape)
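An illustrative NumPy sketch (not part of this commit; variable names are ours) of the reshape trick used above for Ghost BatchNorm: folding the batch into `[B/V, V, ...]`, taking moments over the shifted reduction axes with keepdims, and broadcasting is equivalent to normalizing each virtual batch of `V` samples separately.

```
import numpy as np

B, H, W, C, V = 16, 4, 4, 8, 4           # NHWC input, virtual batch size V
x = np.random.randn(B, H, W, C).astype(np.float32)
eps = 1e-5

# Vectorized form, mirroring the code: [B, H, W, C] -> [B/V, V, H, W, C],
# then reduce over the shifted axes (red_axis [0, 1, 2] becomes [1, 2, 3]).
g = x.reshape(B // V, V, H, W, C)
mean = g.mean(axis=(1, 2, 3), keepdims=True)      # per-group statistics
var = g.var(axis=(1, 2, 3), keepdims=True)
xn = ((g - mean) / np.sqrt(var + eps)).reshape(B, H, W, C)

# Reference: normalize each group of V samples independently.
ref = np.concatenate([
    (x[i:i + V] - x[i:i + V].mean(axis=(0, 1, 2))) / np.sqrt(x[i:i + V].var(axis=(0, 1, 2)) + eps)
    for i in range(0, B, V)
], axis=0)
assert np.allclose(xn, ref, atol=1e-5)
```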
if do_ema_update:
ret = internal_update_bn_ema(
......