seminar-breakout (Shashank Suhas)

Commit f6ede612 authored May 11, 2019 by Yuxin Wu
Better BatchNorm (with ema_update option decoupled from training)
parent 4a46b93d
Showing 2 changed files with 72 additions and 43 deletions.
examples/GAN/GAN.py                +2   -2
tensorpack/models/batch_norm.py    +70  -41
examples/GAN/GAN.py
@@ -169,8 +169,8 @@ class SeparateGANTrainer(TowerTrainer):
         # Build the graph
         self.tower_func = TowerFuncWrapper(model.build_graph, model.get_input_signature())
         with TowerContext('', is_training=True), \
-                argscope(BatchNorm, internal_update=True):
-            # should not hook the updates to both train_op, it will hurt training speed.
+                argscope(BatchNorm, ema_update='internal'):
+            # should not hook the EMA updates to both train_op, it will hurt training speed.
             self.tower_func(*input.get_input_tensors())
             update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
             if len(update_ops):
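For illustration, here is a minimal, hypothetical sketch (not part of this commit) of the usage pattern the GAN.py change above adopts: passing ema_update='internal' to BatchNorm through argscope, so EMA updates run via control dependencies inside the layer instead of being collected into tf.GraphKeys.UPDATE_OPS. The placeholder shape and layer name are illustrative.

    import tensorflow as tf
    from tensorpack import argscope
    from tensorpack.models import BatchNorm
    from tensorpack.tfutils.tower import TowerContext

    images = tf.placeholder(tf.float32, [None, 32, 32, 3], name='images')

    with TowerContext('', is_training=True), \
            argscope(BatchNorm, ema_update='internal'):
        # EMA updates happen inside the layer via control dependencies,
        # so nothing needs to be fetched from tf.GraphKeys.UPDATE_OPS here.
        out = BatchNorm('bn', images)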
tensorpack/models/batch_norm.py
@@ -12,6 +12,7 @@ from ..tfutils.common import get_tf_version_tuple
 from ..tfutils.tower import get_current_tower_context
 from ..utils import logger
 from ..utils.argtools import get_data_format
+from ..utils.develop import log_deprecated
 from .common import VariableHolder, layer_register
 from .tflayer import convert_to_tflayer_args, rename_get_variable
@@ -39,8 +40,8 @@ def get_bn_variables(n_out, use_scale, use_bias, beta_init, gamma_init):
     return beta, gamma, moving_mean, moving_var


-def update_bn_ema(xn, batch_mean, batch_var,
-                  moving_mean, moving_var, decay):
+def internal_update_bn_ema(xn, batch_mean, batch_var,
+                           moving_mean, moving_var, decay):
     update_op1 = moving_averages.assign_moving_average(
         moving_mean, batch_mean, decay, zero_debias=False,
         name='mean_ema_op')
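As an aside, the following standalone sketch (illustrative names, not code from this commit) shows the pattern internal_update_bn_ema implements: update the moving averages with moving_averages.assign_moving_average and force those updates to run whenever the normalized output is evaluated, by making them control dependencies of the output.

    import tensorflow as tf
    from tensorflow.python.training import moving_averages

    x = tf.placeholder(tf.float32, [None, 8])
    batch_mean = tf.reduce_mean(x, axis=0)
    batch_var = tf.reduce_mean(tf.square(x - batch_mean), axis=0)
    moving_mean = tf.get_variable('mean/EMA', [8], trainable=False,
                                  initializer=tf.zeros_initializer())
    moving_var = tf.get_variable('variance/EMA', [8], trainable=False,
                                 initializer=tf.ones_initializer())

    update_op1 = moving_averages.assign_moving_average(
        moving_mean, batch_mean, 0.9, zero_debias=False, name='mean_ema_op')
    update_op2 = moving_averages.assign_moving_average(
        moving_var, batch_var, 0.9, zero_debias=False, name='var_ema_op')

    xn = (x - batch_mean) / tf.sqrt(batch_var + 1e-5)
    # "internal" update: evaluating `output` also runs the two EMA update ops.
    with tf.control_dependencies([update_op1, update_op2]):
        output = tf.identity(xn, name='output')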
@@ -71,8 +72,9 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
               gamma_initializer=tf.ones_initializer(),
               virtual_batch_size=None,
               data_format='channels_last',
-              internal_update=False,
-              sync_statistics=None):
+              ema_update='default',
+              sync_statistics=None,
+              internal_update=None):
     """
     Almost equivalent to `tf.layers.batch_normalization`, but different (and more powerful)
     in the following:
@@ -80,21 +82,29 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
     1. Accepts an alternative `data_format` option when `axis` is None. For 2D input, this argument will be ignored.
     2. Default value for `momentum` and `epsilon` is different.
     3. Default value for `training` is automatically obtained from tensorpack's `TowerContext`, but can be overwritten.
-    4. Support the ``internal_update`` option, which cover more use cases than the standard collection-based update.
-    5. Support the ``sync_statistics`` option, which is very useful in small-batch models.
+    4. Support the ``ema_update`` option, which cover more use cases than the standard EMA update.
+    5. Support the ``sync_statistics`` option, which implements "SyncBN" and is very useful in small-batch models.

     Args:
-        internal_update (bool): if False, add EMA update ops to
-            `tf.GraphKeys.UPDATE_OPS`. If True, update EMA inside the layer by control dependencies.
-            They are very similar in speed, but `internal_update=True` is recommended and can be helpful when:
-
-            1. BatchNorm is used inside dynamic control flow.
-               The collection-based update does not support dynamic control flows.
-            2. BatchNorm layer is sometimes unused (e.g., when you have two networks to train alternatively).
-               Putting all update ops into a single collection will waste a lot of compute.
-
-            Corresponding TF issue: https://github.com/tensorflow/tensorflow/issues/14699
-        sync_statistics (str or None): one of None, "nccl", or "horovod".
+        training (bool): if True, use per-batch statistics to normalize. Otherwise, use stored EMA
+            to normalize. By default, it is equal to `get_current_tower_context().is_training`.
+            This is not a good argument name, but it is what the Tensorflow layer uses.
+        ema_update (str): Only effective when ``training=True``. It has the following options:
+
+            * "default": same as "collection". Because this is the default behavior in tensorflow.
+            * "skip": do not update EMA.
+            * "collection": Add EMA update ops to collection `tf.GraphKeys.UPDATE_OPS`.
+              The ops in the collection will be run automatically by the callback :class:`RunUpdateOps`.
+            * "internal": EMA is updated inside this layer itself by control dependencies.
+              It has similar speed to "collection", but "internal" is recommended and can be helpful when:
+
+              1. BatchNorm is used inside dynamic control flow.
+                 The collection-based update does not support dynamic control flows.
+              2. BatchNorm layer is sometimes unused (e.g., when you have two networks to train alternatively).
+                 Putting all update ops into a single collection will waste a lot of compute.
+
+              Corresponding TF issue: https://github.com/tensorflow/tensorflow/issues/14699
+        sync_statistics (str or None): one of None, "nccl", or "horovod". It determines how to compute the
+            "per-batch statistics" when ``training==True``.

             By default (None), it uses statistics of the input tensor to normalize during training.
             This is the standard way BatchNorm was implemented in most frameworks.
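To make the "collection" option above concrete, here is a hedged sketch in plain TensorFlow (not tensorpack code): the layer registers its EMA update ops in tf.GraphKeys.UPDATE_OPS, and the training loop has to run them together with the train op, which is what tensorpack's RunUpdateOps callback automates. Shapes and the toy loss are made up for the example.

    import tensorflow as tf

    x = tf.placeholder(tf.float32, [None, 16])
    y = tf.placeholder(tf.float32, [None, 1])

    net = tf.layers.batch_normalization(x, training=True)   # adds EMA ops to UPDATE_OPS
    net = tf.layers.dense(net, 1)
    loss = tf.reduce_mean(tf.square(net - y))

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        # every training step now also runs the EMA updates of every layer in the collection
        train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)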
@@ -119,15 +129,15 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
             If different GPUs execute one BatchNorm layer for different number of times
             (e.g., if some GPUs do not execute it), this layer may hang.

-            This option only has effect when `training == get_current_tower_context().training == True`.
-
-            This option is also known as "Cross-GPU BatchNorm" as mentioned in:
+            This option is also known as "SyncBN" or "Cross-GPU BatchNorm" as mentioned in:
             `MegDet: A Large Mini-Batch Object Detector <https://arxiv.org/abs/1711.07240>`_.
             Corresponding TF issue: https://github.com/tensorflow/tensorflow/issues/18222.

-            When `sync_statistics` is enabled, `internal_update` will be set to True automatically.
+            When `sync_statistics` is enabled, `ema_update` is set to "internal" automatically.
             This is to avoid running `UPDATE_OPS`, which requires synchronization.

+        internal_update: deprecated option. Don't use.
+
     Variable Names:

     * ``beta``: the bias term. Will be zero-inited by default.
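A hedged usage sketch of the sync_statistics option documented above; the layer name and input shape are made up, and a single-tower graph is shown only for brevity (in practice this is meant for multi-GPU training).

    import tensorflow as tf
    from tensorpack.models import BatchNorm
    from tensorpack.tfutils.tower import TowerContext

    x = tf.placeholder(tf.float32, [None, 32, 32, 64])
    with TowerContext('', is_training=True):
        # per-batch mean/variance are aggregated across GPUs before normalizing
        out = BatchNorm('sync_bn', x, sync_statistics='nccl')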
@@ -136,16 +146,15 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
     * ``variance/EMA``: the moving average of variance.

     Note:
-        Combinations of ``training`` and ``ctx.is_training``:
-
-        * ``training == ctx.is_training``: standard BN, EMA are maintained during training
-          and used during inference. This is the default.
-        * ``training and not ctx.is_training``: still use batch statistics in inference.
-        * ``not training and ctx.is_training``: use EMA to normalize in
-          training. This is useful when you load a pre-trained BN and
-          don't want to fine tune the EMA. EMA will not be updated in
-          this case.
+        This layer is more flexible than the standard "BatchNorm" layer and provides more features:
+
+        1. No matter whether you're doing training or not, you can set the `training` argument
+           to use batch statistics / EMA statistics.
+           i.e., you can use batch statistics during inference, or use EMA statistics during training.
+           Using EMA statistics in training is useful when you load a pre-trained BN and
+           don't want to update it.
+        2. As long as `training=True`, `sync_statistics` and `ema_update` option will take effect.
     """
+    ctx = get_current_tower_context()
     # parse shapes
     data_format = get_data_format(data_format, keras_mode=False)
     shape = inputs.get_shape().as_list()
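A small, hypothetical sketch of point 1 in the Note above: inside a training tower, pass training=False to normalize with the stored EMA of a pre-trained BN and leave it untouched. Names and shapes are illustrative.

    import tensorflow as tf
    from tensorpack.models import BatchNorm
    from tensorpack.tfutils.tower import TowerContext

    x = tf.placeholder(tf.float32, [None, 64])
    with TowerContext('', is_training=True):
        # training=False overrides ctx.is_training: use EMA statistics and never update them
        frozen = BatchNorm('frozen_bn', x, training=False)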
@@ -155,6 +164,23 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
         sync_statistics = sync_statistics.lower()
     assert sync_statistics in [None, 'nccl', 'horovod'], sync_statistics
+    assert ema_update in ["default", "collection", "internal", "skip"]
+    if internal_update is not None:
+        log_deprecated("BatchNorm(internal_update=)", "Use ema_update='internal' instead!", "2020-01-01")
+        assert ema_update == 'default', \
+            "Do not use internal_update and ema_update together! internal_update is deprecated"
+        ema_update = "internal" if internal_update else "collection"
+    if ema_update == "default":
+        ema_update = "collection"
+
+    # Logic:
+    # 1. EMA update is possible only when we compute batch statistics (training=True)
+    # 2. We know that in training, non-main training tower does not need EMA update
+    #    We don't know about what to do in prediction context, so be conservative and do the update.
+    # 3. User can explicitly disable update by "skip".
+    do_ema_update = training and \
+        (ctx.is_main_training_tower or not ctx.is_training) \
+        and (ema_update != "skip")

     if axis is None:
         if ndims == 2:
             axis = 1
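Restating the do_ema_update rule above as a standalone helper, purely for readability; the function name is hypothetical and not part of the commit.

    def should_update_ema(training, is_main_training_tower, is_training_context, ema_update):
        # 1. EMA can only be updated when batch statistics are computed (training=True).
        # 2. During training only the main tower updates EMA; in a prediction context,
        #    be conservative and do the update.
        # 3. "skip" disables the update explicitly.
        return bool(training) and \
            (is_main_training_tower or not is_training_context) and \
            ema_update != "skip"

    # For example, a non-main training tower never updates the EMA:
    assert should_update_ema(True, False, True, "collection") is False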
@@ -163,12 +189,12 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
         assert axis in [1, 3], axis

     num_chan = shape[axis]
+    TF_version = get_tf_version_tuple()

     # parse training/ctx
-    ctx = get_current_tower_context()
     if training is None:
         training = ctx.is_training
     training = bool(training)
-    TF_version = get_tf_version_tuple()

     freeze_bn_backward = not training and ctx.is_training
     if freeze_bn_backward:
         assert TF_version >= (1, 4), \
@@ -177,12 +203,14 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
             logger.warn("[BatchNorm] Using moving_mean/moving_variance in training.")
             # Using moving_mean/moving_variance in training, which means we
             # loaded a pre-trained BN and only fine-tuning the affine part.

-    if sync_statistics is None or not (training and ctx.is_training):
+    do_sync_bn = (sync_statistics is not None) and training
+
+    if not do_sync_bn:
+        # Use the builtin layer for anything except for sync-bn
         coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
         with rename_get_variable(
                 {'moving_mean': 'mean/EMA',
                  'moving_variance': 'variance/EMA'}):
             tf_args = dict(
                 axis=axis,
                 momentum=momentum, epsilon=epsilon,
@@ -204,16 +232,17 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
             layer = tf.layers.BatchNormalization(**tf_args)
             xn = layer.apply(inputs, training=training, scope=tf.get_variable_scope())

-        # maintain EMA only on one GPU is OK, even in replicated mode.
-        # because during training, EMA isn't used
+        # Add EMA variables to the correct collection
         if ctx.is_main_training_tower:
             for v in layer.non_trainable_variables:
                 if isinstance(v, tf.Variable):
                     tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, v)

-        if not ctx.is_main_training_tower or internal_update:
+        if not do_ema_update:
             restore_collection(coll_bk)
-        if training and internal_update:
+        if do_ema_update and ema_update == "internal":
+            # Implement "internal" update.
+            restore_collection(coll_bk)
             assert layer.updates
             with tf.control_dependencies(layer.updates):
                 ret = tf.identity(xn, name='output')
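For reference, a hedged standalone sketch (plain TensorFlow, illustrative names, not the commit's code) of what the "internal" branch above does with the built-in layer: the EMA assign ops exposed as layer.updates are forced to run before the output tensor is produced.

    import tensorflow as tf

    x = tf.placeholder(tf.float32, [None, 16])
    layer = tf.layers.BatchNormalization(axis=1, momentum=0.9, epsilon=1e-5)
    xn = layer.apply(x, training=True)

    assert layer.updates  # the EMA update ops created by the layer
    with tf.control_dependencies(layer.updates):
        output = tf.identity(xn, name='output')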
@@ -301,8 +330,8 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
             inputs, batch_mean, batch_var,
             beta, gamma, epsilon)

-        if ctx.is_main_training_tower:
-            ret = update_bn_ema(
+        if do_ema_update:
+            ret = internal_update_bn_ema(
                 xn, batch_mean_vec, batch_var_vec, moving_mean, moving_var, momentum)
         else:
             ret = tf.identity(xn, name='output')