Commit e0a7e8f9
authored Mar 11, 2020 by Yuxin Wu
parent 9f4600a0

Better support for virtual_batch_size

Showing 2 changed files with 99 additions and 76 deletions:

README.md                          +1   -1
tensorpack/models/batch_norm.py    +98  -75
README.md

@@ -67,7 +67,7 @@ Dependencies:
 + Python 3.3+.
 + Python bindings for OpenCV. (Optional, but required by a lot of features)
-+ TensorFlow ≥ 1.3, < 2. (Not required if you only want to use `tensorpack.dataflow` alone as a data processing library)
++ TensorFlow ≥ 1.5, < 2. (Not required if you only want to use `tensorpack.dataflow` alone as a data processing library)
 
 ```
 pip install --upgrade git+https://github.com/tensorpack/tensorpack.git
 # or add `--user` to install to user's local directories
tensorpack/models/batch_norm.py

@@ -11,7 +11,6 @@ from ..tfutils.common import get_tf_version_tuple
 from ..tfutils.tower import get_current_tower_context
 from ..utils import logger
 from ..utils.argtools import get_data_format, log_once
-from ..utils.develop import log_deprecated
 from .common import VariableHolder, layer_register
 from .tflayer import convert_to_tflayer_args, rename_get_variable
 from .utils import disable_autograph
@@ -60,6 +59,59 @@ def internal_update_bn_ema(xn, batch_mean, batch_var,
     return tf.identity(xn, name='output')
 
 
+def get_sync_bn_mean_var(inputs, red_axis, sync_statistics):
+    ctx = get_current_tower_context()
+    batch_mean = tf.reduce_mean(inputs, axis=red_axis)
+    batch_mean_square = tf.reduce_mean(tf.square(inputs), axis=red_axis)
+
+    TF_version = get_tf_version_tuple()
+    if sync_statistics == 'nccl':
+        num_dev = ctx.total
+        if num_dev == 1:
+            logger.warn("BatchNorm(sync_statistics='nccl') is used with only one tower!")
+        else:
+            assert TF_version >= (1, 10), \
+                "Cross-GPU BatchNorm is only supported in TF>=1.10 ." \
+                "Upgrade TF or apply this patch manually: https://github.com/tensorflow/tensorflow/pull/20360"
+
+            if TF_version <= (1, 12):
+                try:
+                    from tensorflow.contrib.nccl.python.ops.nccl_ops import _validate_and_load_nccl_so  # deprecated
+                except Exception:
+                    pass
+                else:
+                    _validate_and_load_nccl_so()
+                from tensorflow.contrib.nccl.ops import gen_nccl_ops  # deprecated
+            else:
+                from tensorflow.python.ops import gen_nccl_ops
+            shared_name = re.sub('tower[0-9]+/', '', tf.get_variable_scope().name)
+            batch_mean = gen_nccl_ops.nccl_all_reduce(
+                input=batch_mean,
+                reduction='sum',
+                num_devices=num_dev,
+                shared_name=shared_name + '_NCCL_mean') * (1.0 / num_dev)
+            batch_mean_square = gen_nccl_ops.nccl_all_reduce(
+                input=batch_mean_square,
+                reduction='sum',
+                num_devices=num_dev,
+                shared_name=shared_name + '_NCCL_mean_square') * (1.0 / num_dev)
+    elif sync_statistics == 'horovod':
+        # Require https://github.com/uber/horovod/pull/331
+        import horovod.tensorflow as hvd
+        if hvd.size() == 1:
+            logger.warn("BatchNorm(sync_statistics='horovod') is used with only one process!")
+        else:
+            import horovod
+            hvd_version = tuple(map(int, horovod.__version__.split('.')[:3]))
+            assert hvd_version >= (0, 13, 6), "sync_statistics=horovod needs horovod>=0.13.6 !"
+
+            batch_mean = hvd.allreduce(batch_mean, average=True)
+            batch_mean_square = hvd.allreduce(batch_mean_square, average=True)
+
+    batch_var = batch_mean_square - tf.square(batch_mean)
+    return batch_mean, batch_var
+
+
 @layer_register()
 @convert_to_tflayer_args(
     args_names=[],
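The helper added above derives the variance from all-reduced moments as Var[x] = E[x^2] - E[x]^2. As a sanity check (illustrative NumPy only, not part of the commit), averaging per-GPU E[x] and E[x^2] over towers of equal batch size reproduces the statistics of the combined batch:

```python
import numpy as np

# Four "GPUs", each holding an equal share of the batch (shape [32, 64]).
per_gpu = [np.random.randn(32, 64) for _ in range(4)]

# What the all-reduce computes: the average of per-GPU E[x] and E[x^2].
mean = np.mean([x.mean(axis=0) for x in per_gpu], axis=0)
mean_sq = np.mean([np.square(x).mean(axis=0) for x in per_gpu], axis=0)
var = mean_sq - np.square(mean)          # Var[x] = E[x^2] - E[x]^2

# Same statistics computed directly on the full batch.
full = np.concatenate(per_gpu, axis=0)
assert np.allclose(mean, full.mean(axis=0))
assert np.allclose(var, full.var(axis=0))
```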
@@ -78,8 +130,7 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
               virtual_batch_size=None,
               data_format='channels_last',
               ema_update='default',
-              sync_statistics=None,
-              internal_update=None):
+              sync_statistics=None):
     """
     A more powerful version of `tf.layers.batch_normalization`. It differs from
     the offical one in the following aspects:
@@ -90,11 +141,19 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
        User-provided value can overwrite this behavior.
     4. Support the ``ema_update`` option, which covers broader use cases than the standard EMA update.
     5. Support the ``sync_statistics`` option, which implements "SyncBN" and is very useful in small-batch models.
+    6. Better support of the ``virtual_batch_size`` option that does not have the bugs in ``tf.layers``.
 
     Args:
         training (bool): if True, use per-batch statistics to normalize. Otherwise, use stored EMA
             to normalize. By default, it is equal to `get_current_tower_context().is_training`.
             This is not a good argument name, but it is what the Tensorflow layer uses.
+        virtual_batch_size (int): implement "Ghost BatchNorm" that normalizes
+            the data with a smaller batch size than the input. Only effective when training is True.
+            The value has to be a divisor of the actual batch size.
+            It does not use the buggy TensorFlow implementation which has the
+            problems of (1) wrong behavior at inference; (2) create variables with unnecessary size=1 dimensions.
+
+            Corresponding TF issue: https://github.com/tensorflow/tensorflow/issues/23050
         ema_update (str): Only effective when ``training=True``. It has the following options:
 
             * "default": same as "collection". Because this is the default behavior in TensorFlow.
@@ -128,7 +187,7 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
             * "horovod": this layer must be used under tensorpack's :class:`HorovodTrainer`.
               It uses the aggregated statistics of the whole batch (across all MPI ranks) to normalize.
-              Note that on single machine this is significantly slower than the "nccl" implementation.
+              Note that on a single machine this is found to be slower than the "nccl" implementation.
 
             When not None, each GPU computes its own E[x] and E[x^2],
             which are then averaged among all GPUs to compute global mean & variance.
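For the `sync_statistics` modes compared above, a hedged sketch of enabling SyncBN for every BatchNorm layer in a model via tensorpack's `argscope` (`argscope` and `BatchNorm` are existing tensorpack APIs; `build_model` and `images` are placeholders):

```python
from tensorpack import argscope
from tensorpack.models import BatchNorm

# Aggregate E[x] and E[x^2] across towers/ranks for every BatchNorm layer.
# Use 'horovod' instead of 'nccl' when training with tensorpack's HorovodTrainer.
with argscope(BatchNorm, sync_statistics='nccl'):
    logits = build_model(images)   # placeholder model-building function
```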
@@ -151,8 +210,6 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
             When `sync_statistics` is enabled, `ema_update` is set to "internal" automatically.
             This is to avoid running `UPDATE_OPS`, which requires synchronization.
 
-        internal_update: deprecated option. Don't use.
-
     Variable Names:
 
     * ``beta``: the bias term. Will be zero-inited by default.
@@ -175,6 +232,8 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
     if training is None:
         training = ctx.is_training
     training = bool(training)
+    if not training:
+        virtual_batch_size = None
 
     # parse shapes
     data_format = get_data_format(data_format, keras_mode=False)
@@ -186,11 +245,6 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
     assert sync_statistics in [None, 'nccl', 'horovod'], sync_statistics
     assert ema_update in ["default", "collection", "internal", "skip"]
-    if internal_update is not None:
-        log_deprecated("BatchNorm(internal_update=)", "Use ema_update='internal' instead!", "2020-01-01")
-        assert ema_update == 'default', \
-            "Do not use internal_update and ema_update together! internal_update is deprecated"
-        ema_update = "internal" if internal_update else "collection"
     if ema_update == "default":
         ema_update = "collection"
 
     # Logic:
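With the deprecation shim removed above, any caller still passing `internal_update` has to migrate as the removed `log_deprecated` message suggested (a sketch; `x` is a placeholder tensor):

```python
from tensorpack.models import BatchNorm

# Before this commit:  out = BatchNorm('bn', x, internal_update=True)
# After this commit:
out = BatchNorm('bn', x, ema_update='internal')
```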
@@ -211,12 +265,8 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
         assert axis in [1, 3], axis
     num_chan = shape[axis]
 
-    TF_version = get_tf_version_tuple()
-
     freeze_bn_backward = not training and ctx.is_training
     if freeze_bn_backward:
-        assert TF_version >= (1, 4), \
-            "Fine tuning a BatchNorm model with fixed statistics needs TF>=1.4!"
         if ctx.is_main_training_tower:  # only warn in first tower
             log_once("Some BatchNorm layer uses moving_mean/moving_variance in training.", func='warn')
         # Using moving_mean/moving_variance in training, which means we
@@ -224,8 +274,9 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
     do_sync_bn = (sync_statistics is not None) and training
 
-    if not do_sync_bn:
-        # Use the builtin layer for anything except for sync-bn
+    if not do_sync_bn and not virtual_batch_size:
+        # Use the builtin layer for regular per-GPU BN.
+        # Use our own implementation for SyncBN and GhostBN
         coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
         with rename_get_variable(
                 {'moving_mean': 'mean/EMA',
@@ -239,10 +290,6 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
             # https://github.com/tensorflow/tensorflow/issues/10857#issuecomment-410185429
             fused=(ndims == 4 and axis in [1, 3] and not freeze_bn_backward),
             _reuse=tf.get_variable_scope().reuse)
-        if TF_version >= (1, 5):
-            tf_args['virtual_batch_size'] = virtual_batch_size
-        else:
-            assert virtual_batch_size is None, "Feature not supported in this version of TF!"
         use_fp16 = inputs.dtype == tf.float16
         if use_fp16:
             # non-fused does not support fp16; fused does not support all layouts.
@@ -279,65 +326,39 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
             vh.beta = layer.beta
     else:
         red_axis = [0] if ndims == 2 else ([0, 2, 3] if axis == 1 else [0, 1, 2])
+        beta, gamma, moving_mean, moving_var = get_bn_variables(
+            num_chan, scale, center, beta_initializer, gamma_initializer)
 
+        assert sync_statistics is None or virtual_batch_size is None, "Cannot use SyncBN and GhostBN together!"
         new_shape = None  # don't need to reshape unless ...
-        if ndims == 4 and axis == 1:
-            new_shape = [1, num_chan, 1, 1]
-
-        batch_mean = tf.reduce_mean(inputs, axis=red_axis)
-        batch_mean_square = tf.reduce_mean(tf.square(inputs), axis=red_axis)
-
-        if sync_statistics == 'nccl':
-            num_dev = ctx.total
-            if num_dev == 1:
-                logger.warn("BatchNorm(sync_statistics='nccl') is used with only one tower!")
-            else:
-                assert TF_version >= (1, 10), \
-                    "Cross-GPU BatchNorm is only supported in TF>=1.10 ." \
-                    "Upgrade TF or apply this patch manually: https://github.com/tensorflow/tensorflow/pull/20360"
-
-                if TF_version <= (1, 12):
-                    try:
-                        from tensorflow.contrib.nccl.python.ops.nccl_ops import _validate_and_load_nccl_so  # deprecated
-                    except Exception:
-                        pass
-                    else:
-                        _validate_and_load_nccl_so()
-                    from tensorflow.contrib.nccl.ops import gen_nccl_ops  # deprecated
-                else:
-                    from tensorflow.python.ops import gen_nccl_ops
-                shared_name = re.sub('tower[0-9]+/', '', tf.get_variable_scope().name)
-                batch_mean = gen_nccl_ops.nccl_all_reduce(
-                    input=batch_mean,
-                    reduction='sum',
-                    num_devices=num_dev,
-                    shared_name=shared_name + '_NCCL_mean') * (1.0 / num_dev)
-                batch_mean_square = gen_nccl_ops.nccl_all_reduce(
-                    input=batch_mean_square,
-                    reduction='sum',
-                    num_devices=num_dev,
-                    shared_name=shared_name + '_NCCL_mean_square') * (1.0 / num_dev)
-        elif sync_statistics == 'horovod':
-            # Require https://github.com/uber/horovod/pull/331
-            import horovod.tensorflow as hvd
-            if hvd.size() == 1:
-                logger.warn("BatchNorm(sync_statistics='horovod') is used with only one process!")
-            else:
-                import horovod
-                hvd_version = tuple(map(int, horovod.__version__.split('.')[:3]))
-                assert hvd_version >= (0, 13, 6), "sync_statistics=horovod needs horovod>=0.13.6 !"
-
-                batch_mean = hvd.allreduce(batch_mean, average=True)
-                batch_mean_square = hvd.allreduce(batch_mean_square, average=True)
-        batch_var = batch_mean_square - tf.square(batch_mean)
-        batch_mean_vec = batch_mean
-        batch_var_vec = batch_var
-
-        beta, gamma, moving_mean, moving_var = get_bn_variables(
-            num_chan, scale, center, beta_initializer, gamma_initializer)
+        if sync_statistics is not None:
+            # sync bn
+            batch_mean, batch_var = get_sync_bn_mean_var(inputs, red_axis, sync_statistics)
+            batch_mean_vec = batch_mean
+            batch_var_vec = batch_var
+            if ndims == 4 and axis == 1:
+                new_shape = [1, num_chan, 1, 1]
+                batch_mean = tf.reshape(batch_mean, new_shape)
+                batch_var = tf.reshape(batch_var, new_shape)
+        else:
+            orig_shape = tf.shape(inputs)
+            inputs = tf.reshape(
+                inputs, tf.concat([[-1, virtual_batch_size],
+                                   tf.shape(inputs)[1:]], axis=0))
+            # B/V, V, ...
+            red_axis = [x + 1 for x in red_axis]
+            new_shape = [1] * (ndims + 1)
+            new_shape[axis + 1] = num_chan
+            batch_mean, batch_var = tf.nn.moments(inputs, red_axis, keepdims=True)
+            # B/V, C
+            # vec for EMA update: use the first one only to mimic per-GPU BN
+            batch_mean_vec = tf.reshape(batch_mean[0], [num_chan])
+            batch_var_vec = tf.reshape(batch_var[0], [num_chan])
         if new_shape is not None:
             batch_mean = tf.reshape(batch_mean, new_shape)
             batch_var = tf.reshape(batch_var, new_shape)
 
         # Using fused_batch_norm(is_training=False) is actually slightly faster,
         # but hopefully this call will be JITed in the future.
         xn = tf.nn.batch_normalization(
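The new GhostBN branch above reshapes the input to `[B/V, V, ...]` and takes moments over the ghost-batch axis, so every group of `virtual_batch_size` samples is normalized independently. An illustrative NumPy rendering of the same trick (shapes and epsilon are made up; this is not the commit's code):

```python
import numpy as np

B, C, V = 16, 8, 4                      # batch, channels, virtual batch size
x = np.random.randn(B, C)

xg = x.reshape(-1, V, C)                # [B/V, V, C]: groups of V samples
mean = xg.mean(axis=1, keepdims=True)   # per-ghost-batch E[x],   [B/V, 1, C]
var = xg.var(axis=1, keepdims=True)     # per-ghost-batch Var[x], [B/V, 1, C]
xn = ((xg - mean) / np.sqrt(var + 1e-5)).reshape(B, C)   # back to [B, C]
```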
@@ -348,6 +369,8 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
         xn = tf.nn.batch_normalization(
             inputs, batch_mean, batch_var,
             beta, gamma, epsilon)
+        if virtual_batch_size is not None:
+            xn = tf.reshape(xn, orig_shape)
 
         if do_ema_update:
             ret = internal_update_bn_ema(