Commit 013565d6
Authored Apr 19, 2018 by Yuxin Wu
Use tf.layers.BatchNormalization for implementation (#627)
Parent: edac0543
Showing 6 changed files with 229 additions and 108 deletions
examples/GAN/Image2Image.py                 +1    -4
examples/GAN/README.md                      +0    -1
tensorpack/callbacks/inference_runner.py    +6    -3
tensorpack/models/_old_batch_norm.py        +170  -0
tensorpack/models/batch_norm.py             +46   -99
tensorpack/tfutils/summary.py               +6    -1
examples/GAN/Image2Image.py

@@ -24,9 +24,6 @@ To train Image-to-Image translation model with image pairs:
    # you can download some data from the original authors:
    # https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/
Speed:
    On GTX1080 with BATCH=1, the speed is about 9.3it/s (the original torch version is 9.5it/s)
Training visualization will appear be in tensorboard.
To visualize on test set:
    ./Image2Image.py --sample --data /path/to/test/datadir --mode {AtoB,BtoA} --load model

@@ -71,7 +68,7 @@ class Model(GANModelDesc):
    def generator(self, imgs):
        # imgs: input: 256x256xch
        # U-Net structure, it's slightly different from the original on the location of relu/lrelu
        with argscope(BatchNorm, use_local_stat=True), \
        with argscope(BatchNorm, training=True), \
                argscope(Dropout, is_training=True):
            # always use local stat for BN, and apply dropout even in testing
            with argscope(Conv2D, kernel_size=4, strides=2, activation=BNLReLU):
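The only functional change in this hunk is the renamed argscope key: `use_local_stat=True` becomes `training=True`. A minimal sketch of the new spelling follows; the imports, the bare `TowerContext`, and the layer name `bn_demo` are illustrative assumptions, not part of the commit.

import tensorflow as tf
from tensorpack import argscope
from tensorpack.models import BatchNorm
from tensorpack.tfutils.tower import TowerContext

image = tf.placeholder(tf.float32, [None, 256, 256, 3], name='image')
with TowerContext('', is_training=True):
    # training=True plays the role of the old use_local_stat=True:
    # always normalize with per-batch statistics instead of the EMA.
    with argscope(BatchNorm, training=True):
        normed = BatchNorm('bn_demo', image)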
examples/GAN/README.md

@@ -20,7 +20,6 @@ Reproduce the following GAN-related methods, 100~200 lines each:
+ CycleGAN ([Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593))
Please see the __docstring__ in each script for detailed usage and pretrained models. MultiGPU training is supported.
## [DCGAN.py](DCGAN.py)
tensorpack/callbacks/inference_runner.py

@@ -56,11 +56,14 @@ def _inference_context():
class InferenceRunnerBase(Callback):
    """ Base class for inference runner.
    Please note that InferenceRunner will use `input.size()` to determine
    Note:
        1. InferenceRunner will use `input.size()` to determine
           how much iterations to run, so you're responsible to ensure that
           `size()` is accurate.
           `size()` is reasonable.
        Also, InferenceRunner assumes that `trainer.model` exists.
        2. Only works with instances of `TowerTrainer`.
    """

    def __init__(self, input, infs):
        """
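As a reminder of how this base class is used in practice, here is a hedged construction sketch; the FakeData shapes and the 'cost' tensor name are assumptions, and the callback only does real work once attached to a `TowerTrainer`'s callback list.

from tensorpack import InferenceRunner, ScalarStats
from tensorpack.dataflow import FakeData

# the runner will call size() on this dataflow to bound its inference loop
dataset_test = FakeData([[4, 28, 28, 3], [4]], size=8)
cb = InferenceRunner(dataset_test, [ScalarStats('cost')])
# later: passed in the trainer's callback list together with a TowerTrainer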
tensorpack/models/_old_batch_norm.py  (new file, mode 100644)

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: _old_batch_norm.py

import tensorflow as tf
from tensorflow.contrib.framework import add_model_variable
from tensorflow.python.training import moving_averages

from ..utils import logger
from ..utils.argtools import get_data_format
from ..tfutils.tower import get_current_tower_context
from ..tfutils.common import get_tf_version_number
from .common import layer_register, VariableHolder
from .tflayer import convert_to_tflayer_args

"""
Old Custom BN Implementation, Kept Here For Future Reference
"""


def get_bn_variables(n_out, use_scale, use_bias, gamma_init):
    if use_bias:
        beta = tf.get_variable('beta', [n_out], initializer=tf.constant_initializer())
    else:
        beta = tf.zeros([n_out], name='beta')
    if use_scale:
        gamma = tf.get_variable('gamma', [n_out], initializer=gamma_init)
    else:
        gamma = tf.ones([n_out], name='gamma')
    # x * gamma + beta
    moving_mean = tf.get_variable('mean/EMA', [n_out],
                                  initializer=tf.constant_initializer(), trainable=False)
    moving_var = tf.get_variable('variance/EMA', [n_out],
                                 initializer=tf.constant_initializer(1.0), trainable=False)
    return beta, gamma, moving_mean, moving_var


def update_bn_ema(xn, batch_mean, batch_var,
                  moving_mean, moving_var, decay, internal_update):
    # TODO is there a way to use zero_debias in multi-GPU?
    update_op1 = moving_averages.assign_moving_average(
        moving_mean, batch_mean, decay, zero_debias=False,
        name='mean_ema_op')
    update_op2 = moving_averages.assign_moving_average(
        moving_var, batch_var, decay, zero_debias=False,
        name='var_ema_op')

    if internal_update:
        with tf.control_dependencies([update_op1, update_op2]):
            return tf.identity(xn, name='output')
    else:
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op1)
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op2)
        return tf.identity(xn, name='output')


@layer_register()
@convert_to_tflayer_args(
    args_names=[],
    name_mapping={
        'use_bias': 'center',
        'use_scale': 'scale',
        'gamma_init': 'gamma_initializer',
        'decay': 'momentum',
        'use_local_stat': 'training'
    })
def BatchNorm(inputs, training=None, momentum=0.9, epsilon=1e-5,
              center=True, scale=True,
              gamma_initializer=tf.ones_initializer(),
              data_format='channels_last',
              internal_update=False):
    """
    Mostly equivalent to `tf.layers.batch_normalization`, but difference in
    the following:
    1. Accepts `data_format` rather than `axis`. For 2D input, this argument will be ignored.
    2. Default value for `momentum` and `epsilon` is different.
    3. Default value for `training` is automatically obtained from `TowerContext`.
    4. Support the `internal_update` option.

    Args:
        internal_update (bool): if False, add EMA update ops to
            `tf.GraphKeys.UPDATE_OPS`. If True, update EMA inside the layer
            by control dependencies.

    Variable Names:
    * ``beta``: the bias term. Will be zero-inited by default.
    * ``gamma``: the scale term. Will be one-inited by default. Input will be transformed by ``x * gamma + beta``.
    * ``mean/EMA``: the moving average of mean.
    * ``variance/EMA``: the moving average of variance.

    Note:
        1. About multi-GPU training: moving averages across GPUs are not aggregated.
           Batch statistics are computed independently. This is consistent with most frameworks.
        2. Combinations of ``training`` and ``ctx.is_training``:
            * ``training == ctx.is_training``: standard BN, EMA are
              maintained during training and used during inference. This is
              the default.
            * ``training and not ctx.is_training``: still use batch statistics in inference.
            * ``not training and ctx.is_training``: use EMA to normalize in
              training. This is useful when you load a pre-trained BN and
              don't want to fine tune the EMA. EMA will not be updated in
              this case.
    """
    data_format = get_data_format(data_format, tfmode=False)
    shape = inputs.get_shape().as_list()
    ndims = len(shape)
    assert ndims in [2, 4]
    if ndims == 2:
        data_format = 'NHWC'
    if data_format == 'NCHW':
        n_out = shape[1]
    else:
        n_out = shape[-1]  # channel
    assert n_out is not None, "Input to BatchNorm cannot have unknown channels!"
    beta, gamma, moving_mean, moving_var = get_bn_variables(n_out, scale, center, gamma_initializer)

    ctx = get_current_tower_context()
    use_local_stat = training
    if use_local_stat is None:
        use_local_stat = ctx.is_training
    use_local_stat = bool(use_local_stat)

    if use_local_stat:
        if ndims == 2:
            inputs = tf.reshape(inputs, [-1, 1, 1, n_out])    # fused_bn only takes 4D input
            # fused_bn has error using NCHW? (see #190)
        xn, batch_mean, batch_var = tf.nn.fused_batch_norm(
            inputs, gamma, beta, epsilon=epsilon,
            is_training=True, data_format=data_format)
        if ndims == 2:
            xn = tf.squeeze(xn, [1, 2])
    else:
        if ctx.is_training:
            assert get_tf_version_number() >= 1.4, \
                "Fine tuning a BatchNorm model with fixed statistics is only " \
                "supported after https://github.com/tensorflow/tensorflow/pull/12580 "
            if ctx.is_main_training_tower:  # only warn in first tower
                logger.warn("[BatchNorm] Using moving_mean/moving_variance in training.")
            # Using moving_mean/moving_variance in training, which means we
            # loaded a pre-trained BN and only fine-tuning the affine part.
            xn, _, _ = tf.nn.fused_batch_norm(
                inputs, gamma, beta,
                mean=moving_mean, variance=moving_var, epsilon=epsilon,
                data_format=data_format, is_training=False)
        else:
            if ndims == 4:
                xn, _, _ = tf.nn.fused_batch_norm(
                    inputs, gamma, beta,
                    mean=moving_mean, variance=moving_var, epsilon=epsilon,
                    data_format=data_format, is_training=False)
            else:
                # avoid the reshape if possible (when channel is the last dimension)
                xn = tf.nn.batch_normalization(
                    inputs, moving_mean, moving_var, beta, gamma, epsilon)

    # maintain EMA only on one GPU is OK, even in replicated mode.
    # because training time doesn't use EMA
    if ctx.is_main_training_tower:
        add_model_variable(moving_mean)
        add_model_variable(moving_var)
    if ctx.is_main_training_tower and use_local_stat:
        ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, momentum, internal_update)
    else:
        ret = tf.identity(xn, name='output')

    vh = ret.variables = VariableHolder(mean=moving_mean, variance=moving_var)
    if scale:
        vh.gamma = gamma
    if center:
        vh.beta = beta
    return ret
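The `training` / `ctx.is_training` combinations documented in the docstring above can be summarized with a small standalone illustration (plain Python, not tensorpack code):

def bn_mode(training, ctx_is_training):
    # mirrors the old implementation: use_local_stat defaults to ctx.is_training
    use_local_stat = ctx_is_training if training is None else bool(training)
    if use_local_stat:
        # batch statistics; the EMA is updated only on the main training tower
        return 'batch statistics' + (', EMA updated' if ctx_is_training else '')
    # EMA statistics; never updated in this branch
    return 'EMA statistics (frozen)'

print(bn_mode(None, True))    # default during training
print(bn_mode(None, False))   # default at inference
print(bn_mode(False, True))   # fine-tune only the affine part of a pre-trained BN
print(bn_mode(True, False))   # e.g. pix2pix: batch statistics even at test time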
tensorpack/models/batch_norm.py

@@ -5,7 +5,6 @@
import tensorflow as tf
from tensorflow.contrib.framework import add_model_variable
from tensorflow.python.training import moving_averages
from ..utils import logger
from ..utils.argtools import get_data_format

@@ -13,7 +12,7 @@ from ..tfutils.tower import get_current_tower_context
from ..tfutils.common import get_tf_version_number
from ..tfutils.collection import backup_collection, restore_collection
from .common import layer_register, VariableHolder
from .tflayer import convert_to_tflayer_args
from .tflayer import convert_to_tflayer_args, rename_get_variable

__all__ = ['BatchNorm', 'BatchRenorm']

@@ -21,51 +20,6 @@ __all__ = ['BatchNorm', 'BatchRenorm']
# eps: torch: 1e-5. Lasagne: 1e-4


def get_bn_variables(n_out, use_scale, use_bias, gamma_init):
    if use_bias:
        beta = tf.get_variable('beta', [n_out], initializer=tf.constant_initializer())
    else:
        beta = tf.zeros([n_out], name='beta')
    if use_scale:
        gamma = tf.get_variable('gamma', [n_out], initializer=gamma_init)
    else:
        gamma = tf.ones([n_out], name='gamma')
    # x * gamma + beta
    moving_mean = tf.get_variable('mean/EMA', [n_out],
                                  initializer=tf.constant_initializer(), trainable=False)
    moving_var = tf.get_variable('variance/EMA', [n_out],
                                 initializer=tf.constant_initializer(1.0), trainable=False)
    return beta, gamma, moving_mean, moving_var


def update_bn_ema(xn, batch_mean, batch_var,
                  moving_mean, moving_var, decay, internal_update):
    # TODO is there a way to use zero_debias in multi-GPU?
    update_op1 = moving_averages.assign_moving_average(
        moving_mean, batch_mean, decay, zero_debias=False,
        name='mean_ema_op')
    update_op2 = moving_averages.assign_moving_average(
        moving_var, batch_var, decay, zero_debias=False,
        name='var_ema_op')

    if internal_update:
        with tf.control_dependencies([update_op1, update_op2]):
            return tf.identity(xn, name='output')
    else:
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op1)
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op2)
        return tf.identity(xn, name='output')


def reshape_for_bn(param, ndims, chan, data_format):
    if ndims == 2:
        shape = [1, chan]
    else:
        shape = [1, 1, 1, chan] if data_format == 'NHWC' else [1, chan, 1, 1]
    return tf.reshape(param, shape)


@layer_register()
@convert_to_tflayer_args(
    args_names=[],

@@ -82,7 +36,7 @@ def BatchNorm(inputs, training=None, momentum=0.9, epsilon=1e-5,
              data_format='channels_last',
              internal_update=False):
    """
    Mostly equivalent to `tf.layers.batch_normalization`, but difference in
    Mostly equivalent to `tf.layers.batch_normalization`, but different in
    the following:
    1. Accepts `data_format` rather than `axis`. For 2D input, this argument will be ignored.

@@ -115,38 +69,23 @@ def BatchNorm(inputs, training=None, momentum=0.9, epsilon=1e-5,
          don't want to fine tune the EMA. EMA will not be updated in
          this case.
    """
    # parse shapes
    data_format = get_data_format(data_format, tfmode=False)
    shape = inputs.get_shape().as_list()
    ndims = len(shape)
    assert ndims in [2, 4]
    assert ndims in [2, 4], ndims
    if ndims == 2:
        data_format = 'NHWC'
    if data_format == 'NCHW':
        n_out = shape[1]
        axis = 1
    else:
        n_out = shape[-1]  # channel
    assert n_out is not None, "Input to BatchNorm cannot have unknown channels!"
    beta, gamma, moving_mean, moving_var = get_bn_variables(n_out, scale, center, gamma_initializer)
    axis = 1 if data_format == 'NCHW' else 3

    # parse training/ctx
    ctx = get_current_tower_context()
    use_local_stat = training
    if use_local_stat is None:
        use_local_stat = ctx.is_training
    use_local_stat = bool(use_local_stat)

    if use_local_stat:
        if ndims == 2:
            inputs = tf.reshape(inputs, [-1, 1, 1, n_out])    # fused_bn only takes 4D input
            # fused_bn has error using NCHW? (see #190)
        xn, batch_mean, batch_var = tf.nn.fused_batch_norm(
            inputs, gamma, beta, epsilon=epsilon,
            is_training=True, data_format=data_format)
        if ndims == 2:
            xn = tf.squeeze(xn, [1, 2])
    else:
        if ctx.is_training:
    if training is None:
        training = ctx.is_training
    training = bool(training)
    if not training and ctx.is_training:
        assert get_tf_version_number() >= 1.4, \
            "Fine tuning a BatchNorm model with fixed statistics is only " \
            "supported after https://github.com/tensorflow/tensorflow/pull/12580 "

@@ -154,36 +93,44 @@ def BatchNorm(inputs, training=None, momentum=0.9, epsilon=1e-5,
            logger.warn("[BatchNorm] Using moving_mean/moving_variance in training.")
        # Using moving_mean/moving_variance in training, which means we
        # loaded a pre-trained BN and only fine-tuning the affine part.
        xn, _, _ = tf.nn.fused_batch_norm(
            inputs, gamma, beta,
            mean=moving_mean, variance=moving_var, epsilon=epsilon,
            data_format=data_format, is_training=False)
        else:
            if ndims == 4:
                xn, _, _ = tf.nn.fused_batch_norm(
                    inputs, gamma, beta,
                    mean=moving_mean, variance=moving_var, epsilon=epsilon,
                    data_format=data_format, is_training=False)
            else:
                # avoid the reshape if possible (when channel is the last dimension)
                xn = tf.nn.batch_normalization(
                    inputs, moving_mean, moving_var, beta, gamma, epsilon)

    coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
    with rename_get_variable(
            {'moving_mean': 'mean/EMA',
             'moving_variance': 'variance/EMA'}):
        layer = tf.layers.BatchNormalization(
            axis=axis,
            momentum=momentum, epsilon=epsilon,
            center=center, scale=scale,
            gamma_initializer=gamma_initializer,
            fused=True)
        xn = layer.apply(inputs, training=training, scope=tf.get_variable_scope())

    # maintain EMA only on one GPU is OK, even in replicated mode.
    # because training time doesn't use EMA
    if ctx.is_main_training_tower:
        add_model_variable(moving_mean)
        add_model_variable(moving_var)
    if ctx.is_main_training_tower and use_local_stat:
        ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, momentum, internal_update)
        for v in layer.non_trainable_variables:
            add_model_variable(v)
    if not ctx.is_main_training_tower or internal_update:
        restore_collection(coll_bk)

    if training and internal_update:
        assert layer.updates
        with tf.control_dependencies(layer.updates):
            ret = tf.identity(xn, name='output')
    else:
        ret = tf.identity(xn, name='output')

    vh = ret.variables = VariableHolder(mean=moving_mean, variance=moving_var)
    vh = ret.variables = VariableHolder(
        moving_mean=layer.moving_mean,
        mean=layer.moving_mean,  # for backward-compatibility
        moving_variance=layer.moving_variance,
        variance=layer.moving_variance)  # for backward-compatibility
    if scale:
        vh.gamma = gamma
        vh.gamma = layer.gamma
    if center:
        vh.beta = beta
        vh.beta = layer.beta
    return ret
tensorpack/tfutils/summary.py

@@ -208,6 +208,8 @@ def add_moving_summary(*args, **kwargs):
        collection (str or None): the name of the collection to add EMA-maintaining ops.
            The default will work together with the default
            :class:`MovingAverageSummary` callback.
        summary_collections ([str]): the names of collections to add the
            summary op. Default is TF's default (`tf.GraphKeys.SUMMARIES`).

    Returns:
        [tf.Tensor]: list of tensors returned by assign_moving_average,

@@ -215,6 +217,7 @@ def add_moving_summary(*args, **kwargs):
    """
    decay = kwargs.pop('decay', 0.95)
    coll = kwargs.pop('collection', MOVING_SUMMARY_OPS_KEY)
    summ_coll = kwargs.pop('summary_collections', None)
    assert len(kwargs) == 0, "Unknown arguments: " + str(kwargs)

    ctx = get_current_tower_context()

@@ -248,7 +251,9 @@ def add_moving_summary(*args, **kwargs):
            zero_debias=True, name=name + '_EMA_apply')
        ema_ops.append(ema_op)
        with tf.name_scope(None):
            tf.summary.scalar(name + '-summary', ema_op)    # write the EMA value as a summary
            tf.summary.scalar(
                name + '-summary', ema_op,
                collections=summ_coll)    # write the EMA value as a summary
    if coll is not None:
        for op in ema_ops:
            tf.add_to_collection(coll, op)
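The `summary_collections` plumbing above ultimately relies on the `collections` argument of `tf.summary.scalar`; a small standalone sketch (the collection name 'my_summaries' is an assumption) shows the effect:

import tensorflow as tf

cost = tf.constant(3.0, name='cost')
# route the summary into a custom collection instead of tf.GraphKeys.SUMMARIES
tf.summary.scalar('cost-summary', cost, collections=['my_summaries'])
print(tf.get_collection('my_summaries'))          # contains the summary op
print(tf.get_collection(tf.GraphKeys.SUMMARIES))  # default collection stays empty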