Commit e0a7e8f9 authored by Yuxin Wu

Better support for virtual_batch_size

parent 9f4600a0
@@ -67,7 +67,7 @@ Dependencies:
 + Python 3.3+.
 + Python bindings for OpenCV. (Optional, but required by a lot of features)
-+ TensorFlow ≥ 1.3, < 2. (Not required if you only want to use `tensorpack.dataflow` alone as a data processing library)
++ TensorFlow ≥ 1.5, < 2. (Not required if you only want to use `tensorpack.dataflow` alone as a data processing library)
 ```
 pip install --upgrade git+https://github.com/tensorpack/tensorpack.git
 # or add `--user` to install to user's local directories
......
@@ -11,7 +11,6 @@ from ..tfutils.common import get_tf_version_tuple
 from ..tfutils.tower import get_current_tower_context
 from ..utils import logger
 from ..utils.argtools import get_data_format, log_once
-from ..utils.develop import log_deprecated
 from .common import VariableHolder, layer_register
 from .tflayer import convert_to_tflayer_args, rename_get_variable
 from .utils import disable_autograph
@@ -60,6 +59,59 @@ def internal_update_bn_ema(xn, batch_mean, batch_var,
     return tf.identity(xn, name='output')
 
 
+def get_sync_bn_mean_var(inputs, red_axis, sync_statistics):
+    ctx = get_current_tower_context()
+    batch_mean = tf.reduce_mean(inputs, axis=red_axis)
+    batch_mean_square = tf.reduce_mean(tf.square(inputs), axis=red_axis)
+
+    TF_version = get_tf_version_tuple()
+    if sync_statistics == 'nccl':
+        num_dev = ctx.total
+        if num_dev == 1:
+            logger.warn("BatchNorm(sync_statistics='nccl') is used with only one tower!")
+        else:
+            assert TF_version >= (1, 10), \
+                "Cross-GPU BatchNorm is only supported in TF>=1.10 ." \
+                "Upgrade TF or apply this patch manually: https://github.com/tensorflow/tensorflow/pull/20360"
+
+            if TF_version <= (1, 12):
+                try:
+                    from tensorflow.contrib.nccl.python.ops.nccl_ops import _validate_and_load_nccl_so  # deprecated
+                except Exception:
+                    pass
+                else:
+                    _validate_and_load_nccl_so()
+                from tensorflow.contrib.nccl.ops import gen_nccl_ops  # deprecated
+            else:
+                from tensorflow.python.ops import gen_nccl_ops
+            shared_name = re.sub('tower[0-9]+/', '', tf.get_variable_scope().name)
+            batch_mean = gen_nccl_ops.nccl_all_reduce(
+                input=batch_mean,
+                reduction='sum',
+                num_devices=num_dev,
+                shared_name=shared_name + '_NCCL_mean') * (1.0 / num_dev)
+            batch_mean_square = gen_nccl_ops.nccl_all_reduce(
+                input=batch_mean_square,
+                reduction='sum',
+                num_devices=num_dev,
+                shared_name=shared_name + '_NCCL_mean_square') * (1.0 / num_dev)
+    elif sync_statistics == 'horovod':
+        # Require https://github.com/uber/horovod/pull/331
+        import horovod.tensorflow as hvd
+        if hvd.size() == 1:
+            logger.warn("BatchNorm(sync_statistics='horovod') is used with only one process!")
+        else:
+            import horovod
+            hvd_version = tuple(map(int, horovod.__version__.split('.')[:3]))
+            assert hvd_version >= (0, 13, 6), "sync_statistics=horovod needs horovod>=0.13.6 !"
+
+            batch_mean = hvd.allreduce(batch_mean, average=True)
+            batch_mean_square = hvd.allreduce(batch_mean_square, average=True)
+    batch_var = batch_mean_square - tf.square(batch_mean)
+    return batch_mean, batch_var
+
+
 @layer_register()
 @convert_to_tflayer_args(
     args_names=[],
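The `get_sync_bn_mean_var` helper added above only aggregates first and second moments: every tower computes E[x] and E[x^2], the all-reduce averages them, and the global variance falls out as E[x^2] - E[x]^2. A minimal NumPy check of that identity (illustration only, not part of the commit; assumes every tower holds the same number of samples):

```
import numpy as np

def sync_bn_stats(per_worker_batches):
    means = [b.mean(axis=0) for b in per_worker_batches]                # per-worker E[x]
    mean_squares = [(b ** 2).mean(axis=0) for b in per_worker_batches]  # per-worker E[x^2]
    mean = np.mean(means, axis=0)             # what the all-reduce + (1/num_dev) computes
    mean_square = np.mean(mean_squares, axis=0)
    var = mean_square - mean ** 2             # same formula as batch_var above
    return mean, var

rng = np.random.default_rng(0)
batches = [rng.normal(size=(32, 8)) for _ in range(4)]   # 4 "towers", 32 samples each
mean, var = sync_bn_stats(batches)
full = np.concatenate(batches)
assert np.allclose(mean, full.mean(axis=0)) and np.allclose(var, full.var(axis=0))
```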
@@ -78,8 +130,7 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
               virtual_batch_size=None,
               data_format='channels_last',
               ema_update='default',
-              sync_statistics=None,
-              internal_update=None):
+              sync_statistics=None):
     """
     A more powerful version of `tf.layers.batch_normalization`. It differs from
     the official one in the following aspects:
@@ -90,11 +141,19 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
        User-provided value can overwrite this behavior.
     4. Support the ``ema_update`` option, which covers broader use cases than the standard EMA update.
     5. Support the ``sync_statistics`` option, which implements "SyncBN" and is very useful in small-batch models.
+    6. Better support of the ``virtual_batch_size`` option, without the bugs in ``tf.layers``.
 
     Args:
         training (bool): if True, use per-batch statistics to normalize. Otherwise, use stored EMA
             to normalize. By default, it is equal to `get_current_tower_context().is_training`.
             This is not a good argument name, but it is what the TensorFlow layer uses.
+        virtual_batch_size (int): implement "Ghost BatchNorm" that normalizes
+            the data with a smaller batch size than the input. Only effective when training is True.
+            The value has to be a divisor of the actual batch size.
+            It does not use the buggy TensorFlow implementation, which
+            (1) behaves incorrectly at inference and (2) creates variables with unnecessary size-1 dimensions.
+            Corresponding TF issue: https://github.com/tensorflow/tensorflow/issues/23050
         ema_update (str): Only effective when ``training=True``. It has the following options:
 
             * "default": same as "collection", because this is the default behavior in TensorFlow.
@@ -128,7 +187,7 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
             * "horovod": this layer must be used under tensorpack's :class:`HorovodTrainer`.
               It uses the aggregated statistics of the whole batch (across all MPI ranks) to normalize.
-              Note that on single machine this is significantly slower than the "nccl" implementation.
+              Note that on a single machine this is found to be slower than the "nccl" implementation.
 
         When not None, each GPU computes its own E[x] and E[x^2],
         which are then averaged among all GPUs to compute global mean & variance.
@@ -151,8 +210,6 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
         When `sync_statistics` is enabled, `ema_update` is set to "internal" automatically.
         This is to avoid running `UPDATE_OPS`, which requires synchronization.
 
-        internal_update: deprecated option. Don't use.
-
     Variable Names:
 
     * ``beta``: the bias term. Will be zero-inited by default.
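For context on why ``sync_statistics`` forces ``ema_update="internal"``: in the "collection" mode the EMA update ops are only added to ``tf.GraphKeys.UPDATE_OPS`` and the training loop has to run them, whereas "internal" attaches the update as a control dependency of the layer's output. A hedged TF1-style sketch of the caller-side pattern for the "collection" mode (illustration only, not tensorpack code; in a real model the BatchNorm layers are what populate ``UPDATE_OPS``):

```
import tensorflow as tf

x = tf.random_normal([64, 32])
w = tf.get_variable('w', [32, 1])
loss = tf.reduce_mean(tf.square(tf.matmul(x, w)))

# ema_update="collection": EMA update ops live in UPDATE_OPS and the caller
# must run them, e.g. by making them a dependency of the train op.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

# ema_update="internal": the EMA update is already a control dependency of the
# BatchNorm output, so running train_op alone also updates the EMA. This is why
# sync_statistics forces "internal": no separate UPDATE_OPS pass, which would
# require synchronization across towers, is needed.
```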
@@ -175,6 +232,8 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
     if training is None:
         training = ctx.is_training
     training = bool(training)
+    if not training:
+        virtual_batch_size = None  # Ghost BN only applies when normalizing with per-batch statistics
 
     # parse shapes
     data_format = get_data_format(data_format, keras_mode=False)
@@ -186,11 +245,6 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
     assert sync_statistics in [None, 'nccl', 'horovod'], sync_statistics
     assert ema_update in ["default", "collection", "internal", "skip"]
-    if internal_update is not None:
-        log_deprecated("BatchNorm(internal_update=)", "Use ema_update='internal' instead!", "2020-01-01")
-        assert ema_update == 'default', \
-            "Do not use internal_update and ema_update together! internal_update is deprecated"
-        ema_update = "internal" if internal_update else "collection"
     if ema_update == "default":
         ema_update = "collection"
     # Logic:
@@ -211,12 +265,8 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
         assert axis in [1, 3], axis
     num_chan = shape[axis]
 
-    TF_version = get_tf_version_tuple()
-
     freeze_bn_backward = not training and ctx.is_training
     if freeze_bn_backward:
-        assert TF_version >= (1, 4), \
-            "Fine tuning a BatchNorm model with fixed statistics needs TF>=1.4!"
         if ctx.is_main_training_tower:  # only warn in first tower
             log_once("Some BatchNorm layer uses moving_mean/moving_variance in training.", func='warn')
         # Using moving_mean/moving_variance in training, which means we
@@ -224,8 +274,9 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
     do_sync_bn = (sync_statistics is not None) and training
 
-    if not do_sync_bn:
-        # Use the builtin layer for anything except for sync-bn
+    if not do_sync_bn and not virtual_batch_size:
+        # Use the builtin layer for regular per-GPU BN.
+        # Use our own implementation for SyncBN and GhostBN
         coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
         with rename_get_variable(
                 {'moving_mean': 'mean/EMA',
@@ -239,10 +290,6 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
             # https://github.com/tensorflow/tensorflow/issues/10857#issuecomment-410185429
             fused=(ndims == 4 and axis in [1, 3] and not freeze_bn_backward),
             _reuse=tf.get_variable_scope().reuse)
-        if TF_version >= (1, 5):
-            tf_args['virtual_batch_size'] = virtual_batch_size
-        else:
-            assert virtual_batch_size is None, "Feature not supported in this version of TF!"
         use_fp16 = inputs.dtype == tf.float16
         if use_fp16:
             # non-fused does not support fp16; fused does not support all layouts.
@@ -279,65 +326,39 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
             vh.beta = layer.beta
     else:
         red_axis = [0] if ndims == 2 else ([0, 2, 3] if axis == 1 else [0, 1, 2])
 
+        beta, gamma, moving_mean, moving_var = get_bn_variables(
+            num_chan, scale, center, beta_initializer, gamma_initializer)
+
+        assert sync_statistics is None or virtual_batch_size is None, "Cannot use SyncBN and GhostBN together!"
         new_shape = None  # don't need to reshape unless ...
-        if ndims == 4 and axis == 1:
-            new_shape = [1, num_chan, 1, 1]
-
-        batch_mean = tf.reduce_mean(inputs, axis=red_axis)
-        batch_mean_square = tf.reduce_mean(tf.square(inputs), axis=red_axis)
-
-        if sync_statistics == 'nccl':
-            num_dev = ctx.total
-            if num_dev == 1:
-                logger.warn("BatchNorm(sync_statistics='nccl') is used with only one tower!")
-            else:
-                assert TF_version >= (1, 10), \
-                    "Cross-GPU BatchNorm is only supported in TF>=1.10 ." \
-                    "Upgrade TF or apply this patch manually: https://github.com/tensorflow/tensorflow/pull/20360"
-
-                if TF_version <= (1, 12):
-                    try:
-                        from tensorflow.contrib.nccl.python.ops.nccl_ops import _validate_and_load_nccl_so  # deprecated
-                    except Exception:
-                        pass
-                    else:
-                        _validate_and_load_nccl_so()
-                    from tensorflow.contrib.nccl.ops import gen_nccl_ops  # deprecated
-                else:
-                    from tensorflow.python.ops import gen_nccl_ops
-                shared_name = re.sub('tower[0-9]+/', '', tf.get_variable_scope().name)
-                batch_mean = gen_nccl_ops.nccl_all_reduce(
-                    input=batch_mean,
-                    reduction='sum',
-                    num_devices=num_dev,
-                    shared_name=shared_name + '_NCCL_mean') * (1.0 / num_dev)
-                batch_mean_square = gen_nccl_ops.nccl_all_reduce(
-                    input=batch_mean_square,
-                    reduction='sum',
-                    num_devices=num_dev,
-                    shared_name=shared_name + '_NCCL_mean_square') * (1.0 / num_dev)
-        elif sync_statistics == 'horovod':
-            # Require https://github.com/uber/horovod/pull/331
-            import horovod.tensorflow as hvd
-            if hvd.size() == 1:
-                logger.warn("BatchNorm(sync_statistics='horovod') is used with only one process!")
-            else:
-                import horovod
-                hvd_version = tuple(map(int, horovod.__version__.split('.')[:3]))
-                assert hvd_version >= (0, 13, 6), "sync_statistics=horovod needs horovod>=0.13.6 !"
-
-                batch_mean = hvd.allreduce(batch_mean, average=True)
-                batch_mean_square = hvd.allreduce(batch_mean_square, average=True)
-        batch_var = batch_mean_square - tf.square(batch_mean)
-        batch_mean_vec = batch_mean
-        batch_var_vec = batch_var
-
-        beta, gamma, moving_mean, moving_var = get_bn_variables(
-            num_chan, scale, center, beta_initializer, gamma_initializer)
+        if sync_statistics is not None:
+            # sync bn
+            batch_mean, batch_var = get_sync_bn_mean_var(inputs, red_axis, sync_statistics)
+            batch_mean_vec = batch_mean
+            batch_var_vec = batch_var
+
+            if ndims == 4 and axis == 1:
+                new_shape = [1, num_chan, 1, 1]
+                batch_mean = tf.reshape(batch_mean, new_shape)
+                batch_var = tf.reshape(batch_var, new_shape)
+        else:
+            orig_shape = tf.shape(inputs)
+            inputs = tf.reshape(
+                inputs,
+                tf.concat([[-1, virtual_batch_size],
+                           tf.shape(inputs)[1:]], axis=0))
+            # B/V, V, ...
+            red_axis = [x + 1 for x in red_axis]
+            new_shape = [1] * (ndims + 1)
+            new_shape[axis + 1] = num_chan
+
+            batch_mean, batch_var = tf.nn.moments(inputs, red_axis, keepdims=True)
+            # B/V, C
+            # vec for EMA update: use the first one only to mimic per-GPU BN
+            batch_mean_vec = tf.reshape(batch_mean[0], [num_chan])
+            batch_var_vec = tf.reshape(batch_var[0], [num_chan])
+
         if new_shape is not None:
-            batch_mean = tf.reshape(batch_mean, new_shape)
-            batch_var = tf.reshape(batch_var, new_shape)
             # Using fused_batch_norm(is_training=False) is actually slightly faster,
             # but hopefully this call will be JITed in the future.
             xn = tf.nn.batch_normalization(
@@ -348,6 +369,8 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
             xn = tf.nn.batch_normalization(
                 inputs, batch_mean, batch_var,
                 beta, gamma, epsilon)
+        if virtual_batch_size is not None:
+            xn = tf.reshape(xn, orig_shape)
 
         if do_ema_update:
             ret = internal_update_bn_ema(
......
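Finally, a hedged usage sketch of the options this commit touches (not from the commit; it assumes tensorpack's usual layer-calling convention, an active TowerContext, and TF1 graph mode):

```
import tensorflow as tf
from tensorpack.models import BatchNorm
from tensorpack.tfutils.tower import TowerContext

x = tf.placeholder(tf.float32, [None, 224, 224, 3])
with TowerContext('', is_training=True):
    # Ghost BN: normalize virtual sub-batches of 8 samples each
    y = BatchNorm('ghost_bn', x, virtual_batch_size=8)
    # SyncBN (cross-GPU statistics); per the diff, it cannot be combined with Ghost BN
    # z = BatchNorm('sync_bn', x, sync_statistics='nccl')
```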