Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
6b10019e
Commit
6b10019e
authored
Dec 29, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Support internal_update in BN
parent
35527038
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
12 deletions
+17
-12
examples/FasterRCNN/data.py
examples/FasterRCNN/data.py
+2
-3
tensorpack/models/batch_norm.py
tensorpack/models/batch_norm.py
+15
-9
No files found.
examples/FasterRCNN/data.py
View file @
6b10019e
...
...
@@ -8,7 +8,7 @@ import copy
from
tensorpack.utils.argtools
import
memoized
,
log_once
from
tensorpack.dataflow
import
(
MapData
,
imgaug
,
TestDataSpeed
,
PrefetchDataZMQ
,
imgaug
,
TestDataSpeed
,
PrefetchDataZMQ
,
MultiProcessMapData
,
MapDataComponent
,
DataFromList
)
# import tensorpack.utils.viz as tpviz
...
...
@@ -251,8 +251,7 @@ def get_train_dataflow(add_mask=False):
# tpviz.interactive_imshow(viz)
return
ret
ds
=
MapData
(
ds
,
preprocess
)
ds
=
PrefetchDataZMQ
(
ds
,
1
)
ds
=
MultiProcessMapData
(
ds
,
3
,
preprocess
)
return
ds
...
...
tensorpack/models/batch_norm.py
View file @
6b10019e
...
...
@@ -37,7 +37,8 @@ def get_bn_variables(n_out, use_scale, use_bias, gamma_init):
return
beta
,
gamma
,
moving_mean
,
moving_var
def
update_bn_ema
(
xn
,
batch_mean
,
batch_var
,
moving_mean
,
moving_var
,
decay
):
def
update_bn_ema
(
xn
,
batch_mean
,
batch_var
,
moving_mean
,
moving_var
,
decay
,
internal_update
):
# TODO is there a way to use zero_debias in multi-GPU?
update_op1
=
moving_averages
.
assign_moving_average
(
moving_mean
,
batch_mean
,
decay
,
zero_debias
=
False
,
...
...
@@ -46,9 +47,10 @@ def update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay):
moving_var
,
batch_var
,
decay
,
zero_debias
=
False
,
name
=
'var_ema_op'
)
# TODO add an option, and maybe enable it for replica mode?
# with tf.control_dependencies([update_op1, update_op2]):
# return tf.identity(xn, name='output')
if
internal_update
:
with
tf
.
control_dependencies
([
update_op1
,
update_op2
]):
return
tf
.
identity
(
xn
,
name
=
'output'
)
else
:
tf
.
add_to_collection
(
tf
.
GraphKeys
.
UPDATE_OPS
,
update_op1
)
tf
.
add_to_collection
(
tf
.
GraphKeys
.
UPDATE_OPS
,
update_op2
)
return
xn
...
...
@@ -65,7 +67,8 @@ def reshape_for_bn(param, ndims, chan, data_format):
@
layer_register
()
def
BatchNorm
(
x
,
use_local_stat
=
None
,
decay
=
0.9
,
epsilon
=
1e-5
,
use_scale
=
True
,
use_bias
=
True
,
gamma_init
=
tf
.
constant_initializer
(
1.0
),
data_format
=
'NHWC'
):
gamma_init
=
tf
.
constant_initializer
(
1.0
),
data_format
=
'NHWC'
,
internal_update
=
False
):
"""
Batch Normalization layer, as described in the paper:
`Batch Normalization: Accelerating Deep Network Training by
...
...
@@ -79,6 +82,9 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
epsilon (float): epsilon to avoid divide-by-zero.
use_scale, use_bias (bool): whether to use the extra affine transformation or not.
gamma_init: initializer for gamma (the scale).
internal_update (bool): if False, add the EMA update ops to
`tf.GraphKeys.UPDATE_OPS`. If True, update the EMA inside the layer
(as control dependencies of the output), which will be slightly slower.
Returns:
tf.Tensor: a tensor named ``output`` with the same shape as x.
...
...
@@ -161,7 +167,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
add_model_variable
(
moving_mean
)
add_model_variable
(
moving_var
)
if
ctx
.
is_main_training_tower
and
use_local_stat
:
ret
=
update_bn_ema
(
xn
,
batch_mean
,
batch_var
,
moving_mean
,
moving_var
,
decay
)
ret
=
update_bn_ema
(
xn
,
batch_mean
,
batch_var
,
moving_mean
,
moving_var
,
decay
,
internal_update
)
else
:
ret
=
tf
.
identity
(
xn
,
name
=
'output'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment