Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
b9e79e1c
Commit
b9e79e1c
authored
May 05, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add VariableHolder for BN
parent
0e4ddfd6
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
4 deletions
+10
-4
tensorpack/models/batch_norm.py
tensorpack/models/batch_norm.py
+10
-3
tensorpack/models/common.py
tensorpack/models/common.py
+0
-1
No files found.
tensorpack/models/batch_norm.py
View file @
b9e79e1c
...
...
@@ -9,7 +9,7 @@ from tensorflow.python.training import moving_averages
from
..tfutils.tower
import
get_current_tower_context
from
..utils
import
logger
from
.common
import
layer_register
from
.common
import
layer_register
,
VariableHolder
__all__
=
[
'BatchNorm'
,
'BatchRenorm'
]
...
...
@@ -220,9 +220,16 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
# maintain EMA only on one GPU.
if
ctx
.
is_main_training_tower
:
ret
urn
update_bn_ema
(
xn
,
batch_mean
,
batch_var
,
moving_mean
,
moving_var
,
decay
)
ret
=
update_bn_ema
(
xn
,
batch_mean
,
batch_var
,
moving_mean
,
moving_var
,
decay
)
else
:
return
tf
.
identity
(
xn
,
name
=
'output'
)
ret
=
tf
.
identity
(
xn
,
name
=
'output'
)
vh
=
ret
.
variables
=
VariableHolder
(
mean
=
moving_mean
,
variance
=
moving_var
)
if
use_scale
:
vh
.
gamma
=
gamma
if
use_bias
:
vh
.
beta
=
beta
return
ret
# TODO support NCHW
...
...
tensorpack/models/common.py
View file @
b9e79e1c
...
...
@@ -31,7 +31,6 @@ class VariableHolder(object):
self
.
_add_variable
(
k
,
v
)
def
_add_variable
(
self
,
name
,
var
):
print
(
name
,
var
.
name
)
assert
name
not
in
self
.
_vars
self
.
_vars
[
name
]
=
var
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment