Commit e0a7e8f9, authored Mar 11, 2020 by Yuxin Wu

Better support for virtual_batch_size

Parent: 9f4600a0
Showing 2 changed files, with 99 additions and 76 deletions:

- README.md (+1, -1)
- tensorpack/models/batch_norm.py (+98, -75)
README.md

````diff
@@ -67,7 +67,7 @@ Dependencies:
 + Python 3.3+.
 + Python bindings for OpenCV. (Optional, but required by a lot of features)
-+ TensorFlow ≥ 1.3, < 2. (Not required if you only want to use `tensorpack.dataflow` alone as a data processing library)
++ TensorFlow ≥ 1.5, < 2. (Not required if you only want to use `tensorpack.dataflow` alone as a data processing library)
 
 ```
 pip install --upgrade git+https://github.com/tensorpack/tensorpack.git
 # or add `--user` to install to user's local directories
````
...
...
tensorpack/models/batch_norm.py
...
...
```diff
@@ -11,7 +11,6 @@ from ..tfutils.common import get_tf_version_tuple
 from ..tfutils.tower import get_current_tower_context
 from ..utils import logger
 from ..utils.argtools import get_data_format, log_once
-from ..utils.develop import log_deprecated
 from .common import VariableHolder, layer_register
 from .tflayer import convert_to_tflayer_args, rename_get_variable
 from .utils import disable_autograph
```
...
...
```diff
@@ -60,6 +59,59 @@ def internal_update_bn_ema(xn, batch_mean, batch_var,
     return tf.identity(xn, name='output')
 
 
+def get_sync_bn_mean_var(inputs, red_axis, sync_statistics):
+    ctx = get_current_tower_context()
+    batch_mean = tf.reduce_mean(inputs, axis=red_axis)
+    batch_mean_square = tf.reduce_mean(tf.square(inputs), axis=red_axis)
+
+    TF_version = get_tf_version_tuple()
+    if sync_statistics == 'nccl':
+        num_dev = ctx.total
+        if num_dev == 1:
+            logger.warn("BatchNorm(sync_statistics='nccl') is used with only one tower!")
+        else:
+            assert TF_version >= (1, 10), \
+                "Cross-GPU BatchNorm is only supported in TF>=1.10 ." \
+                "Upgrade TF or apply this patch manually: https://github.com/tensorflow/tensorflow/pull/20360"
+
+            if TF_version <= (1, 12):
+                try:
+                    from tensorflow.contrib.nccl.python.ops.nccl_ops import _validate_and_load_nccl_so  # deprecated
+                except Exception:
+                    pass
+                else:
+                    _validate_and_load_nccl_so()
+                from tensorflow.contrib.nccl.ops import gen_nccl_ops  # deprecated
+            else:
+                from tensorflow.python.ops import gen_nccl_ops
+            shared_name = re.sub('tower[0-9]+/', '', tf.get_variable_scope().name)
+            batch_mean = gen_nccl_ops.nccl_all_reduce(
+                input=batch_mean,
+                reduction='sum',
+                num_devices=num_dev,
+                shared_name=shared_name + '_NCCL_mean') * (1.0 / num_dev)
+            batch_mean_square = gen_nccl_ops.nccl_all_reduce(
+                input=batch_mean_square,
+                reduction='sum',
+                num_devices=num_dev,
+                shared_name=shared_name + '_NCCL_mean_square') * (1.0 / num_dev)
+    elif sync_statistics == 'horovod':
+        # Require https://github.com/uber/horovod/pull/331
+        import horovod.tensorflow as hvd
+        if hvd.size() == 1:
+            logger.warn("BatchNorm(sync_statistics='horovod') is used with only one process!")
+        else:
+            import horovod
+            hvd_version = tuple(map(int, horovod.__version__.split('.')[:3]))
+            assert hvd_version >= (0, 13, 6), "sync_statistics=horovod needs horovod>=0.13.6 !"
+
+            batch_mean = hvd.allreduce(batch_mean, average=True)
+            batch_mean_square = hvd.allreduce(batch_mean_square, average=True)
+
+    batch_var = batch_mean_square - tf.square(batch_mean)
+    return batch_mean, batch_var
+
+
 @layer_register()
 @convert_to_tflayer_args(
     args_names=[],
```
...
...
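The new helper all-reduces only the first and second moments and derives the variance afterwards via Var[x] = E[x²] − E[x]². A minimal NumPy sketch (hypothetical per-device arrays, not part of the commit) of why averaging E[x] and E[x²] across devices reproduces the statistics of the combined batch, assuming each device holds the same number of samples:

```python
import numpy as np

# Hypothetical per-device activations for one channel (two "GPUs").
x_dev0 = np.array([0.5, 1.5, 2.0, -1.0])
x_dev1 = np.array([3.0, 0.0, -2.0, 1.0])

# What get_sync_bn_mean_var does conceptually: average E[x] and E[x^2]
# across devices (all-reduce with reduction='sum', scaled by 1/num_dev),
# then derive the variance from the aggregated moments.
mean = (x_dev0.mean() + x_dev1.mean()) / 2
mean_sq = ((x_dev0 ** 2).mean() + (x_dev1 ** 2).mean()) / 2
var = mean_sq - mean ** 2

# Equivalent to the moments of the full (concatenated) batch.
full = np.concatenate([x_dev0, x_dev1])
assert np.isclose(mean, full.mean())
assert np.isclose(var, full.var())
```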
```diff
@@ -78,8 +130,7 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
               virtual_batch_size=None,
               data_format='channels_last',
               ema_update='default',
-              sync_statistics=None, internal_update=None):
+              sync_statistics=None):
     """
     A more powerful version of `tf.layers.batch_normalization`. It differs from
     the official one in the following aspects:
```
...
...
```diff
@@ -90,11 +141,19 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
        User-provided value can overwrite this behavior.
     4. Support the ``ema_update`` option, which covers broader use cases than the standard EMA update.
     5. Support the ``sync_statistics`` option, which implements "SyncBN" and is very useful in small-batch models.
+    6. Better support of the ``virtual_batch_size`` option that does not have the bugs in ``tf.layers``.
 
     Args:
         training (bool): if True, use per-batch statistics to normalize. Otherwise, use stored EMA
             to normalize. By default, it is equal to `get_current_tower_context().is_training`.
             This is not a good argument name, but it is what the Tensorflow layer uses.
+        virtual_batch_size (int): implement "Ghost BatchNorm" that normalizes
+            the data with a smaller batch size than the input. Only effective when training is True.
+            The value has to be a divisor of the actual batch size.
+            It does not use the buggy TensorFlow implementation which has the
+            problems of (1) wrong behavior at inference; (2) create variables with unnecessary size=1 dimensions.
+            Corresponding TF issue: https://github.com/tensorflow/tensorflow/issues/23050
         ema_update (str): Only effective when ``training=True``. It has the following options:
           * "default": same as "collection". Because this is the default behavior in TensorFlow.
```
...
...
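As a usage illustration of the option documented above, here is a minimal sketch (hypothetical tensor and layer names, assumed to run inside a tensorpack tower function so the training mode comes from the tower context; not part of the diff):

```python
from tensorpack.models import BatchNorm

def build_graph(images):
    """Hypothetical tower-function body; `images` is an NHWC float32 tensor
    whose batch size (say 64) is divisible by the virtual batch size."""
    # "Ghost BatchNorm": normalization statistics are computed over
    # sub-batches of 8 samples instead of the full batch of 64.
    # Only effective during training; inference uses the EMA statistics.
    x = BatchNorm('ghost_bn', images, virtual_batch_size=8)
    # Note: mutually exclusive with sync_statistics (see the assert later in the diff).
    return x
```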
```diff
@@ -128,7 +187,7 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
           * "horovod": this layer must be used under tensorpack's :class:`HorovodTrainer`.
             It uses the aggregated statistics of the whole batch (across all MPI ranks) to normalize.
-            Note that on single machine this is significantly slower than the "nccl" implementation.
+            Note that on a single machine this is found to be slower than the "nccl" implementation.
 
         When not None, each GPU computes its own E[x] and E[x^2],
         which are then averaged among all GPUs to compute global mean & variance.
```
...
...
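Similarly, a minimal sketch of enabling cross-GPU statistics (hypothetical names, assumed to run inside a tower function launched by a multi-GPU tensorpack trainer; not part of the diff):

```python
from tensorpack.models import BatchNorm

def build_graph(features):
    """Hypothetical tower-function body run replicated across GPUs."""
    # Each GPU computes E[x] and E[x^2] locally; they are then all-reduced
    # (via NCCL here) so every tower normalizes with the global mean/variance.
    x = BatchNorm('sync_bn', features, sync_statistics='nccl')
    # Under tensorpack's HorovodTrainer one would pass
    # sync_statistics='horovod' instead.
    return x
```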
```diff
@@ -151,8 +210,6 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
         When `sync_statistics` is enabled, `ema_update` is set to "internal" automatically.
         This is to avoid running `UPDATE_OPS`, which requires synchronization.
 
-        internal_update: deprecated option. Don't use.
-
     Variable Names:
 
     * ``beta``: the bias term. Will be zero-inited by default.
```
...
...
```diff
@@ -175,6 +232,8 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
     if training is None:
         training = ctx.is_training
     training = bool(training)
+    if not training:
+        virtual_batch_size = None
 
     # parse shapes
     data_format = get_data_format(data_format, keras_mode=False)
```
...
...
```diff
@@ -186,11 +245,6 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
     assert sync_statistics in [None, 'nccl', 'horovod'], sync_statistics
     assert ema_update in ["default", "collection", "internal", "skip"]
-    if internal_update is not None:
-        log_deprecated("BatchNorm(internal_update=)", "Use ema_update='internal' instead!", "2020-01-01")
-        assert ema_update == 'default', \
-            "Do not use internal_update and ema_update together! internal_update is deprecated"
-        ema_update = "internal" if internal_update else "collection"
     if ema_update == "default":
         ema_update = "collection"
 
     # Logic:
```
...
...
```diff
@@ -211,12 +265,8 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
     assert axis in [1, 3], axis
     num_chan = shape[axis]
 
-    TF_version = get_tf_version_tuple()
-
     freeze_bn_backward = not training and ctx.is_training
     if freeze_bn_backward:
-        assert TF_version >= (1, 4), \
-            "Fine tuning a BatchNorm model with fixed statistics needs TF>=1.4!"
         if ctx.is_main_training_tower:  # only warn in first tower
             log_once("Some BatchNorm layer uses moving_mean/moving_variance in training.", func='warn')
         # Using moving_mean/moving_variance in training, which means we
```
...
...
```diff
@@ -224,8 +274,9 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
     do_sync_bn = (sync_statistics is not None) and training
 
-    if not do_sync_bn:
-        # Use the builtin layer for anything except for sync-bn
+    if not do_sync_bn and not virtual_batch_size:
+        # Use the builtin layer for regular per-GPU BN.
+        # Use our own implementation for SyncBN and GhostBN
         coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
         with rename_get_variable({
                 'moving_mean': 'mean/EMA',
```
...
...
```diff
@@ -239,10 +290,6 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
             # https://github.com/tensorflow/tensorflow/issues/10857#issuecomment-410185429
             fused=(ndims == 4 and axis in [1, 3] and not freeze_bn_backward),
             _reuse=tf.get_variable_scope().reuse)
-        if TF_version >= (1, 5):
-            tf_args['virtual_batch_size'] = virtual_batch_size
-        else:
-            assert virtual_batch_size is None, "Feature not supported in this version of TF!"
         use_fp16 = inputs.dtype == tf.float16
         if use_fp16:
             # non-fused does not support fp16; fused does not support all layouts.
```
...
...
```diff
@@ -279,65 +326,39 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
             vh.beta = layer.beta
     else:
         red_axis = [0] if ndims == 2 else ([0, 2, 3] if axis == 1 else [0, 1, 2])
 
+        beta, gamma, moving_mean, moving_var = get_bn_variables(
+            num_chan, scale, center, beta_initializer, gamma_initializer)
+        assert sync_statistics is None or virtual_batch_size is None, "Cannot use SyncBN and GhostBN together!"
+
         new_shape = None  # don't need to reshape unless ...
-        if ndims == 4 and axis == 1:
-            new_shape = [1, num_chan, 1, 1]
-
-        batch_mean = tf.reduce_mean(inputs, axis=red_axis)
-        batch_mean_square = tf.reduce_mean(tf.square(inputs), axis=red_axis)
-
+        if sync_statistics is not None:
+            # sync bn
+            batch_mean, batch_var = get_sync_bn_mean_var(inputs, red_axis, sync_statistics)
+            batch_mean_vec = batch_mean
+            batch_var_vec = batch_var
-        if sync_statistics == 'nccl':
-            # [... elided here: the old inline NCCL all-reduce of batch_mean / batch_mean_square,
-            #      identical to the code now living in get_sync_bn_mean_var above ...]
-        elif sync_statistics == 'horovod':
-            # [... elided here: the old inline Horovod allreduce, likewise moved into get_sync_bn_mean_var ...]
-        batch_var = batch_mean_square - tf.square(batch_mean)
-
-        batch_mean_vec = batch_mean
-        batch_var_vec = batch_var
+
+            if ndims == 4 and axis == 1:
+                new_shape = [1, num_chan, 1, 1]
+                batch_mean = tf.reshape(batch_mean, new_shape)
+                batch_var = tf.reshape(batch_var, new_shape)
+        else:
+            orig_shape = tf.shape(inputs)
+            inputs = tf.reshape(
+                inputs, tf.concat([[-1, virtual_batch_size], tf.shape(inputs)[1:]], axis=0))  # B/V, V, ...
+            red_axis = [x + 1 for x in red_axis]
+            new_shape = [1] * (ndims + 1)
+            new_shape[axis + 1] = num_chan
+
+            batch_mean, batch_var = tf.nn.moments(inputs, red_axis, keepdims=True)  # B/V, C
+            # vec for EMA update: use the first one only to mimic per-GPU BN
+            batch_mean_vec = tf.reshape(batch_mean[0], [num_chan])
+            batch_var_vec = tf.reshape(batch_var[0], [num_chan])
 
-        beta, gamma, moving_mean, moving_var = get_bn_variables(
-            num_chan, scale, center, beta_initializer, gamma_initializer)
         if new_shape is not None:
-            batch_mean = tf.reshape(batch_mean, new_shape)
-            batch_var = tf.reshape(batch_var, new_shape)
             # Using fused_batch_norm(is_training=False) is actually slightly faster,
             # but hopefully this call will be JITed in the future.
             xn = tf.nn.batch_normalization(
```
...
...
```diff
@@ -348,6 +369,8 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
             xn = tf.nn.batch_normalization(
                 inputs, batch_mean, batch_var, beta, gamma, epsilon)
 
+        if virtual_batch_size is not None:
+            xn = tf.reshape(xn, orig_shape)
 
         if do_ema_update:
             ret = internal_update_bn_ema(
```
...
...
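To make the reshape trick concrete: the batch axis is folded into [B/V, V, ...], moments are taken per virtual batch, and the normalized output is reshaped back to the original shape (the added lines above). A small self-contained NumPy sketch (hypothetical shapes, not part of the commit) of the same mechanics:

```python
import numpy as np

B, V, C = 8, 4, 3          # hypothetical: batch of 8, virtual batch of 4, 3 channels (NHWC, H=W=1)
x = np.random.randn(B, 1, 1, C).astype(np.float32)

# Fold the batch axis: [B, H, W, C] -> [B/V, V, H, W, C]
xv = x.reshape(B // V, V, 1, 1, C)

# Per-virtual-batch moments over (V, H, W); keepdims so they broadcast back.
mean = xv.mean(axis=(1, 2, 3), keepdims=True)
var = xv.var(axis=(1, 2, 3), keepdims=True)

# Normalize each virtual batch with its own statistics, then restore the shape,
# mirroring the final tf.reshape(xn, orig_shape) in the diff.
eps = 1e-5
xn = ((xv - mean) / np.sqrt(var + eps)).reshape(x.shape)

# Sanity check: the first V samples were normalized with the statistics of
# x[:V] only, not of the whole batch.
ref = (x[:V] - x[:V].mean(axis=(0, 1, 2))) / np.sqrt(x[:V].var(axis=(0, 1, 2)) + eps)
assert np.allclose(xn[:V], ref, atol=1e-5)
```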