Shashank Suhas / seminar-breakout / Commits / 43d2bffb

Commit 43d2bffb
authored Feb 28, 2018 by Yuxin Wu

Use some tflayers argument names in batchnorm (#627)
parent 8e5a46a4

Showing 3 changed files with 36 additions and 19 deletions (+36 −19):

  tensorpack/models/batch_norm.py     +34 −17
  tensorpack/models/pool.py            +1  −1
  tensorpack/tfutils/model_utils.py    +1  −1
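In short, this commit renames the BatchNorm/BatchRenorm arguments to their tf.layers spellings (decay → momentum, use_scale → scale, use_bias → center, gamma_init → gamma_initializer) and routes the old names through a new convert_to_tflayer_args decorator, apparently so that existing call sites keep working. A rough usage sketch, assuming tensorpack's registered-layer call convention where the first argument is the scope name (the input tensor below is a made-up placeholder, not part of this commit):

import tensorflow as tf
from tensorpack.models import BatchNorm

x = tf.placeholder(tf.float32, [None, 32, 32, 64])  # hypothetical input tensor

# tf.layers-style names introduced by this commit:
y_new = BatchNorm('bn', x, momentum=0.9, scale=True, center=True)

# Legacy names should still be accepted; the decorator maps them to the new ones:
y_old = BatchNorm('bn_legacy', x, decay=0.9, use_scale=True, use_bias=True)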
tensorpack/models/batch_norm.py

...
@@ -13,6 +13,7 @@ from ..tfutils.tower import get_current_tower_context
 from ..tfutils.common import get_tf_version_number
 from ..tfutils.collection import backup_collection, restore_collection
 from .common import layer_register, VariableHolder
+from .tflayer import convert_to_tflayer_args

 __all__ = ['BatchNorm', 'BatchRenorm']
...
@@ -66,9 +67,17 @@ def reshape_for_bn(param, ndims, chan, data_format):

 @layer_register()
-def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
-              use_scale=True, use_bias=True,
-              gamma_init=tf.constant_initializer(1.0), data_format='channels_last',
+@convert_to_tflayer_args(
+    args_names=[],
+    name_mapping={
+        'use_bias': 'center',
+        'use_scale': 'scale',
+        'gamma_init': 'gamma_initializer',
+        'decay': 'momentum'
+    })
+def BatchNorm(x, use_local_stat=None, momentum=0.9, epsilon=1e-5,
+              scale=True, center=True, gamma_initializer=tf.ones_initializer(),
+              data_format='channels_last',
               internal_update=False):
     """
...
@@ -80,10 +89,10 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
         x (tf.Tensor): a 4D or 2D tensor. When 4D, the layout should match data_format.
         use_local_stat (bool): whether to use mean/var of the current batch or the moving average.
             Defaults to True in training and False in inference.
-        decay (float): decay rate of moving average.
+        momentum (float): momentum of moving average.
         epsilon (float): epsilon to avoid divide-by-zero.
-        use_scale, use_bias (bool): whether to use the extra affine transformation or not.
-        gamma_init: initializer for gamma (the scale).
+        scale, center (bool): whether to use the extra affine transformation or not.
+        gamma_initializer: initializer for gamma (the scale).
         internal_update (bool): if False, add EMA update ops to
             `tf.GraphKeys.UPDATE_OPS`. If True, update EMA inside the layer
             which will be slightly slower.
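As a reminder of what the internal_update flag in the docstring refers to (standard TF 1.x usage, not something introduced by this commit): with internal_update=False the EMA update ops end up in tf.GraphKeys.UPDATE_OPS and the training loop is expected to run them, typically like this ('optimizer' and 'loss' stand in for whatever the surrounding code defines):

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)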
...
@@ -122,7 +131,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
     else:
         n_out = shape[-1]  # channel
     assert n_out is not None, "Input to BatchNorm cannot have unknown channels!"
-    beta, gamma, moving_mean, moving_var = get_bn_variables(n_out, use_scale, use_bias, gamma_init)
+    beta, gamma, moving_mean, moving_var = get_bn_variables(n_out, scale, center, gamma_initializer)

     ctx = get_current_tower_context()
     if use_local_stat is None:
...
@@ -170,21 +179,29 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
     add_model_variable(moving_mean)
     add_model_variable(moving_var)

     if ctx.is_main_training_tower and use_local_stat:
-        ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay, internal_update)
+        ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, momentum, internal_update)
     else:
         ret = tf.identity(xn, name='output')

     vh = ret.variables = VariableHolder(mean=moving_mean, variance=moving_var)
-    if use_scale:
+    if scale:
         vh.gamma = gamma
-    if use_bias:
+    if center:
         vh.beta = beta
     return ret


 @layer_register()
-def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
-                use_scale=True, use_bias=True, gamma_init=None,
+@convert_to_tflayer_args(
+    args_names=[],
+    name_mapping={
+        'use_bias': 'center',
+        'use_scale': 'scale',
+        'gamma_init': 'gamma_initializer',
+        'decay': 'momentum'
+    })
+def BatchRenorm(x, rmax, dmax, momentum=0.9, epsilon=1e-5,
+                scale=True, bias=True, gamma_initializer=None,
                 data_format='channels_last'):
     """
     Batch Renormalization layer, as described in the paper:
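For context on the decay → momentum rename used in update_bn_ema above: under the tf.layers convention the moving statistics follow moving_mean := momentum * moving_mean + (1 - momentum) * batch_mean (and likewise for moving_var), which is exactly the role the old decay parameter played, so the default of 0.9 is unchanged.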
...
@@ -221,15 +238,15 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
     coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
     layer = tf.layers.BatchNormalization(
         axis=1 if data_format == 'channels_first' else 3,
-        momentum=decay, epsilon=epsilon,
-        center=use_bias, scale=use_scale,
+        momentum=momentum, epsilon=epsilon,
+        center=center, scale=scale,
         renorm=True,
         renorm_clipping={
             'rmin': 1.0 / rmax,
             'rmax': rmax,
             'dmax': dmax},
         renorm_momentum=0.99,
-        gamma_initializer=gamma_init,
+        gamma_initializer=gamma_initializer,
         fused=False)
     xn = layer.apply(x, training=ctx.is_training, scope=tf.get_variable_scope())
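The rmax/dmax arguments are forwarded unchanged into tf.layers' renorm_clipping dict; as a made-up illustration, BatchRenorm('brn', x, rmax=3., dmax=5.) would clip the renorm correction factors to r in [1/3, 3] and d in [-5, 5].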
...
@@ -246,8 +263,8 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
     # TODO not sure whether to add moving_mean/moving_var to VH now
     vh = ret.variables = VariableHolder()
-    if use_scale:
+    if scale:
         vh.gamma = layer.gamma
-    if use_bias:
+    if center:
         vh.beta = layer.beta
     return ret
tensorpack/models/pool.py

...
@@ -57,7 +57,6 @@ def AvgPooling(
 @layer_register(log_shape=True)
+@convert_to_tflayer_args(args_names=[], name_mapping={})
 def GlobalAvgPooling(x, data_format='channels_last'):
     """
     Global average pooling as in the paper `Network In Network
...
@@ -70,6 +69,7 @@ def GlobalAvgPooling(x, data_format='channels_last'):
         tf.Tensor: a NC tensor named ``output``.
     """
     assert x.shape.ndims == 4
+    data_format = get_data_format(data_format)
     axis = [1, 2] if data_format == 'channels_last' else [2, 3]
     return tf.reduce_mean(x, axis, name='output')
...
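The newly called get_data_format helper is defined elsewhere in tensorpack and is not part of this diff; my guess is that it simply normalizes legacy 'NHWC'/'NCHW' spellings to the tf.layers-style names, roughly along these lines (a sketch, not the real implementation):

def get_data_format(data_format):
    # Hypothetical sketch: map legacy TF layout strings onto tf.layers-style names.
    mapping = {
        'NHWC': 'channels_last', 'NCHW': 'channels_first',
        'channels_last': 'channels_last', 'channels_first': 'channels_first',
    }
    assert data_format in mapping, "Unknown data_format: {}".format(data_format)
    return mapping[data_format]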
tensorpack/tfutils/model_utils.py

...
@@ -46,7 +46,7 @@ def describe_trainable_vars():
     summary_msg = colored(
         "\nTotal #vars={}, #params={}, size={:.02f}MB".format(
             len(data), total, size_mb), 'cyan')
-    logger.info(colored("Model Parameters: \n", 'cyan') + table + summary_msg)
+    logger.info(colored("Trainable Variables: \n", 'cyan') + table + summary_msg)


 def get_shape_str(tensors):
...