Project: seminar-breakout (Shashank Suhas)
Commit 013565d6, authored Apr 19, 2018 by Yuxin Wu
Use tf.layers.BatchNormalization for implementation (#627)
Parent: edac0543
Showing 6 changed files with 229 additions and 108 deletions (+229 / -108):
examples/GAN/Image2Image.py                  +1 / -4
examples/GAN/README.md                       +0 / -1
tensorpack/callbacks/inference_runner.py     +6 / -3
tensorpack/models/_old_batch_norm.py         +170 / -0
tensorpack/models/batch_norm.py              +46 / -99
tensorpack/tfutils/summary.py                +6 / -1
examples/GAN/Image2Image.py

@@ -24,9 +24,6 @@ To train Image-to-Image translation model with image pairs:
 # you can download some data from the original authors:
 # https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/
-Speed:
-On GTX1080 with BATCH=1, the speed is about 9.3it/s (the original torch version is 9.5it/s)
 Training visualization will appear be in tensorboard.
 To visualize on test set:
 ./Image2Image.py --sample --data /path/to/test/datadir --mode {AtoB,BtoA} --load model
...
@@ -71,7 +68,7 @@ class Model(GANModelDesc):
     def generator(self, imgs):
         # imgs: input: 256x256xch
         # U-Net structure, it's slightly different from the original on the location of relu/lrelu
-        with argscope(BatchNorm, use_local_stat=True), \
+        with argscope(BatchNorm, training=True), \
                 argscope(Dropout, is_training=True):
            # always use local stat for BN, and apply dropout even in testing
            with argscope(Conv2D, kernel_size=4, strides=2, activation=BNLReLU):
...
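The only change in this example is the renamed BatchNorm argument. As a hedged, standalone sketch (assuming tensorpack ~0.8 and TF 1.x APIs; not part of the commit), the new spelling looks like this — `training=True` makes BatchNorm always normalize with per-batch statistics, which is what the pix2pix generator wants:

# Hedged sketch, not from the commit: `use_local_stat=True` is now spelled
# `training=True`. Layer names 'conv0'/'bn0' are arbitrary.
import tensorflow as tf
from tensorpack.models import BatchNorm, Conv2D
from tensorpack.tfutils.argscope import argscope
from tensorpack.tfutils.tower import TowerContext

image = tf.placeholder(tf.float32, [None, 256, 256, 3], name='image')
with TowerContext('', is_training=True):      # BatchNorm reads the tower context
    with argscope(BatchNorm, training=True), \
            argscope(Conv2D, kernel_size=4, strides=2):
        x = Conv2D('conv0', image, 64)        # kernel_size/strides come from argscope
        x = BatchNorm('bn0', x)               # always uses batch statistics here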
examples/GAN/README.md

@@ -20,7 +20,6 @@ Reproduce the following GAN-related methods, 100~200 lines each:
 + CycleGAN ([Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593))

 Please see the __docstring__ in each script for detailed usage and pretrained models. MultiGPU training is supported.

 ## [DCGAN.py](DCGAN.py)
...
tensorpack/callbacks/inference_runner.py

@@ -56,11 +56,14 @@ def _inference_context():
 class InferenceRunnerBase(Callback):
     """ Base class for inference runner.
-    Please note that InferenceRunner will use `input.size()` to determine
-    how much iterations to run, so you're responsible to ensure that
-    `size()` is accurate.
-    Also, InferenceRunner assumes that `trainer.model` exists.
+    Note:
+        1. InferenceRunner will use `input.size()` to determine
+           how much iterations to run, so you're responsible to ensure that
+           `size()` is reasonable.
+        2. Only works with instances of `TowerTrainer`.
     """

     def __init__(self, input, infs):
         """
...
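For context (not part of the diff), this is a hedged fragment of how an InferenceRunner is typically attached to a tensorpack training configuration; `dataset_val` and the tensor names 'cost' and 'wrong-top1' are placeholders that depend on the model:

# Hedged usage fragment; dataset_val, 'cost' and 'wrong-top1' are placeholders.
from tensorpack import InferenceRunner, ScalarStats, ClassificationError

callbacks = [
    # InferenceRunner calls dataset_val.size() to decide how many
    # iterations of inference to run after each epoch.
    InferenceRunner(dataset_val,
                    [ScalarStats('cost'), ClassificationError('wrong-top1')]),
]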
tensorpack/models/_old_batch_norm.py (new file, 0 → 100644)

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: _old_batch_norm.py

import tensorflow as tf
from tensorflow.contrib.framework import add_model_variable
from tensorflow.python.training import moving_averages

from ..utils import logger
from ..utils.argtools import get_data_format
from ..tfutils.tower import get_current_tower_context
from ..tfutils.common import get_tf_version_number
from .common import layer_register, VariableHolder
from .tflayer import convert_to_tflayer_args

"""
Old Custom BN Implementation, Kept Here For Future Reference
"""


def get_bn_variables(n_out, use_scale, use_bias, gamma_init):
    if use_bias:
        beta = tf.get_variable('beta', [n_out], initializer=tf.constant_initializer())
    else:
        beta = tf.zeros([n_out], name='beta')
    if use_scale:
        gamma = tf.get_variable('gamma', [n_out], initializer=gamma_init)
    else:
        gamma = tf.ones([n_out], name='gamma')
    # x * gamma + beta
    moving_mean = tf.get_variable('mean/EMA', [n_out],
                                  initializer=tf.constant_initializer(), trainable=False)
    moving_var = tf.get_variable('variance/EMA', [n_out],
                                 initializer=tf.constant_initializer(1.0), trainable=False)
    return beta, gamma, moving_mean, moving_var


def update_bn_ema(xn, batch_mean, batch_var,
                  moving_mean, moving_var, decay, internal_update):
    # TODO is there a way to use zero_debias in multi-GPU?
    update_op1 = moving_averages.assign_moving_average(
        moving_mean, batch_mean, decay, zero_debias=False,
        name='mean_ema_op')
    update_op2 = moving_averages.assign_moving_average(
        moving_var, batch_var, decay, zero_debias=False,
        name='var_ema_op')

    if internal_update:
        with tf.control_dependencies([update_op1, update_op2]):
            return tf.identity(xn, name='output')
    else:
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op1)
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op2)
        return tf.identity(xn, name='output')


@layer_register()
@convert_to_tflayer_args(
    args_names=[],
    name_mapping={
        'use_bias': 'center',
        'use_scale': 'scale',
        'gamma_init': 'gamma_initializer',
        'decay': 'momentum',
        'use_local_stat': 'training'
    })
def BatchNorm(inputs, training=None, momentum=0.9, epsilon=1e-5,
              center=True, scale=True,
              gamma_initializer=tf.ones_initializer(),
              data_format='channels_last',
              internal_update=False):
    """
    Mostly equivalent to `tf.layers.batch_normalization`, but different in
    the following:

    1. Accepts `data_format` rather than `axis`. For 2D input, this argument will be ignored.
    2. Default value for `momentum` and `epsilon` is different.
    3. Default value for `training` is automatically obtained from `TowerContext`.
    4. Support the `internal_update` option.

    Args:
        internal_update (bool): if False, add EMA update ops to
            `tf.GraphKeys.UPDATE_OPS`. If True, update EMA inside the layer
            by control dependencies.

    Variable Names:

    * ``beta``: the bias term. Will be zero-inited by default.
    * ``gamma``: the scale term. Will be one-inited by default. Input will be transformed by ``x * gamma + beta``.
    * ``mean/EMA``: the moving average of mean.
    * ``variance/EMA``: the moving average of variance.

    Note:
        1. About multi-GPU training: moving averages across GPUs are not aggregated.
           Batch statistics are computed independently. This is consistent with most frameworks.
        2. Combinations of ``training`` and ``ctx.is_training``:
            * ``training == ctx.is_training``: standard BN, EMA are
              maintained during training and used during inference. This is
              the default.
            * ``training and not ctx.is_training``: still use batch statistics in inference.
            * ``not training and ctx.is_training``: use EMA to normalize in
              training. This is useful when you load a pre-trained BN and
              don't want to fine tune the EMA. EMA will not be updated in
              this case.
    """
    data_format = get_data_format(data_format, tfmode=False)
    shape = inputs.get_shape().as_list()
    ndims = len(shape)
    assert ndims in [2, 4]
    if ndims == 2:
        data_format = 'NHWC'
    if data_format == 'NCHW':
        n_out = shape[1]
    else:
        n_out = shape[-1]  # channel
    assert n_out is not None, "Input to BatchNorm cannot have unknown channels!"
    beta, gamma, moving_mean, moving_var = get_bn_variables(n_out, scale, center, gamma_initializer)

    ctx = get_current_tower_context()
    use_local_stat = training
    if use_local_stat is None:
        use_local_stat = ctx.is_training
    use_local_stat = bool(use_local_stat)

    if use_local_stat:
        if ndims == 2:
            inputs = tf.reshape(inputs, [-1, 1, 1, n_out])    # fused_bn only takes 4D input
            # fused_bn has error using NCHW? (see #190)

        xn, batch_mean, batch_var = tf.nn.fused_batch_norm(
            inputs, gamma, beta, epsilon=epsilon,
            is_training=True, data_format=data_format)

        if ndims == 2:
            xn = tf.squeeze(xn, [1, 2])
    else:
        if ctx.is_training:
            assert get_tf_version_number() >= 1.4, \
                "Fine tuning a BatchNorm model with fixed statistics is only " \
                "supported after https://github.com/tensorflow/tensorflow/pull/12580 "
            if ctx.is_main_training_tower:  # only warn in first tower
                logger.warn("[BatchNorm] Using moving_mean/moving_variance in training.")
            # Using moving_mean/moving_variance in training, which means we
            # loaded a pre-trained BN and only fine-tuning the affine part.
            xn, _, _ = tf.nn.fused_batch_norm(
                inputs, gamma, beta,
                mean=moving_mean, variance=moving_var, epsilon=epsilon,
                data_format=data_format, is_training=False)
        else:
            if ndims == 4:
                xn, _, _ = tf.nn.fused_batch_norm(
                    inputs, gamma, beta,
                    mean=moving_mean, variance=moving_var, epsilon=epsilon,
                    data_format=data_format, is_training=False)
            else:
                # avoid the reshape if possible (when channel is the last dimension)
                xn = tf.nn.batch_normalization(
                    inputs, moving_mean, moving_var, beta, gamma, epsilon)

    # maintain EMA only on one GPU is OK, even in replicated mode.
    # because training time doesn't use EMA
    if ctx.is_main_training_tower:
        add_model_variable(moving_mean)
        add_model_variable(moving_var)
    if ctx.is_main_training_tower and use_local_stat:
        ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, momentum, internal_update)
    else:
        ret = tf.identity(xn, name='output')

    vh = ret.variables = VariableHolder(mean=moving_mean, variance=moving_var)
    if scale:
        vh.gamma = gamma
    if center:
        vh.beta = beta
    return ret
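In the file above, `internal_update` only decides where the EMA update ops end up. When it is False, `update_bn_ema()` puts them into `tf.GraphKeys.UPDATE_OPS` and the caller must run them with the train op. A minimal sketch of that standard TF 1.x wiring (with a dummy loss and optimizer; not taken from the commit):

import tensorflow as tf

# Dummy variable/loss purely to illustrate the wiring.
x = tf.random_normal([8, 10])
w = tf.get_variable('w', [10, 1])
loss = tf.reduce_mean(tf.square(tf.matmul(x, w)))
opt = tf.train.GradientDescentOptimizer(0.1)

# With internal_update=False, the two EMA assign ops live in UPDATE_OPS,
# so they must be attached to the training step explicitly:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = opt.minimize(loss)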
tensorpack/models/batch_norm.py

@@ -5,7 +5,6 @@
 import tensorflow as tf
 from tensorflow.contrib.framework import add_model_variable
-from tensorflow.python.training import moving_averages

 from ..utils import logger
 from ..utils.argtools import get_data_format
...
@@ -13,7 +12,7 @@ from ..tfutils.tower import get_current_tower_context
 from ..tfutils.common import get_tf_version_number
 from ..tfutils.collection import backup_collection, restore_collection
 from .common import layer_register, VariableHolder
-from .tflayer import convert_to_tflayer_args
+from .tflayer import convert_to_tflayer_args, rename_get_variable

 __all__ = ['BatchNorm', 'BatchRenorm']
...
@@ -21,51 +20,6 @@ __all__ = ['BatchNorm', 'BatchRenorm']
 # eps: torch: 1e-5. Lasagne: 1e-4

-def get_bn_variables(n_out, use_scale, use_bias, gamma_init):
-    if use_bias:
-        beta = tf.get_variable('beta', [n_out], initializer=tf.constant_initializer())
-    else:
-        beta = tf.zeros([n_out], name='beta')
-    if use_scale:
-        gamma = tf.get_variable('gamma', [n_out], initializer=gamma_init)
-    else:
-        gamma = tf.ones([n_out], name='gamma')
-    # x * gamma + beta
-    moving_mean = tf.get_variable('mean/EMA', [n_out],
-                                  initializer=tf.constant_initializer(), trainable=False)
-    moving_var = tf.get_variable('variance/EMA', [n_out],
-                                 initializer=tf.constant_initializer(1.0), trainable=False)
-    return beta, gamma, moving_mean, moving_var
-
-
-def update_bn_ema(xn, batch_mean, batch_var,
-                  moving_mean, moving_var, decay, internal_update):
-    # TODO is there a way to use zero_debias in multi-GPU?
-    update_op1 = moving_averages.assign_moving_average(
-        moving_mean, batch_mean, decay, zero_debias=False,
-        name='mean_ema_op')
-    update_op2 = moving_averages.assign_moving_average(
-        moving_var, batch_var, decay, zero_debias=False,
-        name='var_ema_op')
-
-    if internal_update:
-        with tf.control_dependencies([update_op1, update_op2]):
-            return tf.identity(xn, name='output')
-    else:
-        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op1)
-        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op2)
-        return tf.identity(xn, name='output')
-
-
-def reshape_for_bn(param, ndims, chan, data_format):
-    if ndims == 2:
-        shape = [1, chan]
-    else:
-        shape = [1, 1, 1, chan] if data_format == 'NHWC' else [1, chan, 1, 1]
-    return tf.reshape(param, shape)
-
-
 @layer_register()
 @convert_to_tflayer_args(
     args_names=[],
...
@@ -82,7 +36,7 @@ def BatchNorm(inputs, training=None, momentum=0.9, epsilon=1e-5,
               data_format='channels_last',
               internal_update=False):
     """
-    Mostly equivalent to `tf.layers.batch_normalization`, but difference in
+    Mostly equivalent to `tf.layers.batch_normalization`, but different in
     the following:

     1. Accepts `data_format` rather than `axis`. For 2D input, this argument will be ignored.
...
@@ -115,38 +69,23 @@ def BatchNorm(inputs, training=None, momentum=0.9, epsilon=1e-5,
             don't want to fine tune the EMA. EMA will not be updated in
             this case.
     """
+    # parse shapes
     data_format = get_data_format(data_format, tfmode=False)
     shape = inputs.get_shape().as_list()
     ndims = len(shape)
-    assert ndims in [2, 4]
+    assert ndims in [2, 4], ndims
+
     if ndims == 2:
         data_format = 'NHWC'
-    if data_format == 'NCHW':
-        n_out = shape[1]
+        axis = 1
     else:
-        n_out = shape[-1]  # channel
-    assert n_out is not None, "Input to BatchNorm cannot have unknown channels!"
-    beta, gamma, moving_mean, moving_var = get_bn_variables(n_out, scale, center, gamma_initializer)
+        axis = 1 if data_format == 'NCHW' else 3

+    # parse training/ctx
     ctx = get_current_tower_context()
-    use_local_stat = training
-    if use_local_stat is None:
-        use_local_stat = ctx.is_training
-    use_local_stat = bool(use_local_stat)
+    if training is None:
+        training = ctx.is_training
+    training = bool(training)

-    if use_local_stat:
-        if ndims == 2:
-            inputs = tf.reshape(inputs, [-1, 1, 1, n_out])    # fused_bn only takes 4D input
-            # fused_bn has error using NCHW? (see #190)
-
-        xn, batch_mean, batch_var = tf.nn.fused_batch_norm(
-            inputs, gamma, beta, epsilon=epsilon,
-            is_training=True, data_format=data_format)
-
-        if ndims == 2:
-            xn = tf.squeeze(xn, [1, 2])
-    else:
-        if ctx.is_training:
-            assert get_tf_version_number() >= 1.4, \
-                "Fine tuning a BatchNorm model with fixed statistics is only " \
-                "supported after https://github.com/tensorflow/tensorflow/pull/12580 "
+    if not training and ctx.is_training:
+        assert get_tf_version_number() >= 1.4, \
+            "Fine tuning a BatchNorm model with fixed statistics is only " \
+            "supported after https://github.com/tensorflow/tensorflow/pull/12580 "
...
@@ -154,36 +93,44 @@ def BatchNorm(inputs, training=None, momentum=0.9, epsilon=1e-5,
             logger.warn("[BatchNorm] Using moving_mean/moving_variance in training.")
         # Using moving_mean/moving_variance in training, which means we
         # loaded a pre-trained BN and only fine-tuning the affine part.
-            xn, _, _ = tf.nn.fused_batch_norm(
-                inputs, gamma, beta,
-                mean=moving_mean, variance=moving_var, epsilon=epsilon,
-                data_format=data_format, is_training=False)
-        else:
-            if ndims == 4:
-                xn, _, _ = tf.nn.fused_batch_norm(
-                    inputs, gamma, beta,
-                    mean=moving_mean, variance=moving_var, epsilon=epsilon,
-                    data_format=data_format, is_training=False)
-            else:
-                # avoid the reshape if possible (when channel is the last dimension)
-                xn = tf.nn.batch_normalization(
-                    inputs, moving_mean, moving_var, beta, gamma, epsilon)
+    coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
+    with rename_get_variable(
+            {'moving_mean': 'mean/EMA',
+             'moving_variance': 'variance/EMA'}):
+        layer = tf.layers.BatchNormalization(
+            axis=axis,
+            momentum=momentum, epsilon=epsilon,
+            center=center, scale=scale,
+            gamma_initializer=gamma_initializer,
+            fused=True
+        )
+        xn = layer.apply(inputs, training=training, scope=tf.get_variable_scope())

     # maintain EMA only on one GPU is OK, even in replicated mode.
     # because training time doesn't use EMA
     if ctx.is_main_training_tower:
-        add_model_variable(moving_mean)
-        add_model_variable(moving_var)
-    if ctx.is_main_training_tower and use_local_stat:
-        ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, momentum, internal_update)
+        for v in layer.non_trainable_variables:
+            add_model_variable(v)
+    if not ctx.is_main_training_tower or internal_update:
+        restore_collection(coll_bk)
+
+    if training and internal_update:
+        assert layer.updates
+        with tf.control_dependencies(layer.updates):
+            ret = tf.identity(xn, name='output')
     else:
         ret = tf.identity(xn, name='output')

-    vh = ret.variables = VariableHolder(mean=moving_mean, variance=moving_var)
+    vh = ret.variables = VariableHolder(
+        moving_mean=layer.moving_mean,
+        mean=layer.moving_mean,  # for backward-compatibility
+        moving_variance=layer.moving_variance,
+        variance=layer.moving_variance)  # for backward-compatibility
     if scale:
-        vh.gamma = gamma
+        vh.gamma = layer.gamma
     if center:
-        vh.beta = beta
+        vh.beta = layer.beta
     return ret
...
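Stripped of the tensorpack plumbing (variable renaming, tower context, VariableHolder), the core of the new implementation is delegation to `tf.layers.BatchNormalization`. A hedged TF 1.x sketch of that pattern, not the library code itself:

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 32, 32, 64])

# The commit builds a layer object like this (axis derived from data_format)
# and reuses its variables and update ops.
layer = tf.layers.BatchNormalization(
    axis=3, momentum=0.9, epsilon=1e-5,
    center=True, scale=True, fused=True)
xn = layer.apply(x, training=True)

# layer.updates holds the moving-average assign ops.  With
# internal_update=True the commit attaches them via control dependencies;
# otherwise tf.layers has already added them to tf.GraphKeys.UPDATE_OPS.
with tf.control_dependencies(layer.updates):
    output = tf.identity(xn, name='output')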
tensorpack/tfutils/summary.py

@@ -208,6 +208,8 @@ def add_moving_summary(*args, **kwargs):
         collection (str or None): the name of the collection to add EMA-maintaining ops.
             The default will work together with the default
             :class:`MovingAverageSummary` callback.
+        summary_collections ([str]): the names of collections to add the
+            summary op. Default is TF's default (`tf.GraphKeys.SUMMARIES`).

     Returns:
         [tf.Tensor]: list of tensors returned by assign_moving_average,
...
@@ -215,6 +217,7 @@ def add_moving_summary(*args, **kwargs):
     """
     decay = kwargs.pop('decay', 0.95)
     coll = kwargs.pop('collection', MOVING_SUMMARY_OPS_KEY)
+    summ_coll = kwargs.pop('summary_collections', None)
     assert len(kwargs) == 0, "Unknown arguments: " + str(kwargs)

     ctx = get_current_tower_context()
...
@@ -248,7 +251,9 @@ def add_moving_summary(*args, **kwargs):
                 zero_debias=True, name=name + '_EMA_apply')
             ema_ops.append(ema_op)
             with tf.name_scope(None):
-                tf.summary.scalar(name + '-summary', ema_op)    # write the EMA value as a summary
+                tf.summary.scalar(
+                    name + '-summary', ema_op,
+                    collections=summ_coll)    # write the EMA value as a summary
     if coll is not None:
         for op in ema_ops:
             tf.add_to_collection(coll, op)
...
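A hedged sketch of what the new keyword enables (the scalar and the collection name 'my_ema_summaries' are illustrative, not from the commit): callers can route the EMA summaries into their own collection instead of the default `tf.GraphKeys.SUMMARIES`.

import tensorflow as tf
from tensorpack.tfutils.summary import add_moving_summary

# A dummy scalar; any scalar tensor works.
cost = tf.reduce_mean(tf.random_normal([8]), name='total_cost')

# New in this commit: send the '-summary' op to a custom collection.
add_moving_summary(cost, summary_collections=['my_ema_summaries'])

# The summaries can then be merged and written separately from the defaults.
summary_op = tf.summary.merge(tf.get_collection('my_ema_summaries'))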