Shashank Suhas / seminar-breakout · Commit 70c9ba8f
Authored Feb 01, 2018 by Yuxin Wu
Parent: bf4d8938

VariableHolder for layer_norm; Do not assert replicated variables.
Showing 2 changed files with 20 additions and 5 deletions:

- tensorpack/graph_builder/training.py (+4, -2)
- tensorpack/models/layer_norm.py (+16, -3)
tensorpack/graph_builder/training.py

```diff
@@ -256,8 +256,10 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
                 logger.error("[SyncMultiGPUReplicatedBuilder] variable "
                              "{} has its prefix {} appears multiple times in its name!".format(v.name, prefix))
             copy_from = var_by_name.get(realname)
-            assert copy_from is not None, var_by_name.keys()
-            post_init_ops.append(v.assign(copy_from.read_value()))
+            if copy_from is not None:
+                post_init_ops.append(v.assign(copy_from.read_value()))
+            else:
+                logger.warn("[ReplicatedTrainer] Cannot find {} in the graph!".format(realname))
         logger.info(
             "'sync_variables_from_main_tower' includes {} operations.".format(len(post_init_ops)))
         return tf.group(*post_init_ops, name='sync_variables_from_main_tower')
```
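The behavioral change here: a replicated `towerN/...` variable whose master copy cannot be found in the graph used to fail a hard `assert`; it is now skipped with a warning, presumably because some tower variables can legitimately lack a master copy. A minimal sketch of the name matching that feeds this snippet (not tensorpack code; the helper and sample variable names are hypothetical):

```python
# Sketch only: each replicated 'towerN/...' variable is paired with the master
# variable whose name lacks the tower prefix; after this commit, a miss warns
# instead of asserting.
def master_name(var_name):
    # Strip the leading 'towerN/' prefix: 'tower1/conv/W:0' -> 'conv/W:0'
    return '/'.join(var_name.split('/')[1:])

var_by_name = {'conv/W:0': '<master variable>'}  # pretend graph contents
for name in ['tower1/conv/W:0', 'tower1/opt/beta1:0']:
    copy_from = var_by_name.get(master_name(name))
    if copy_from is not None:
        print('sync {} <- {}'.format(name, master_name(name)))
    else:
        # After this commit: warn and skip instead of failing an assert.
        print('warn: cannot find {} in the graph!'.format(master_name(name)))
```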
tensorpack/models/layer_norm.py

```diff
@@ -4,7 +4,7 @@
 import tensorflow as tf
 
-from .common import layer_register
+from .common import layer_register, VariableHolder
 
 __all__ = ['LayerNorm', 'InstanceNorm']
```
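`VariableHolder` is defined in `tensorpack/models/common.py`, which this diff does not show. As used below, it acts as a plain attribute bag that a layer hangs its variables on; a minimal sketch of that idea (not the actual class, which may carry extra helpers):

```python
# Sketch of the idea only, not tensorpack's real VariableHolder.
class VariableHolder(object):
    """A bag of a layer's variables, exposed as attributes (vh.gamma, vh.beta)."""
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)
```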
```diff
@@ -51,7 +51,14 @@ def LayerNorm(
     else:
         gamma = tf.ones([1] * ndims, name='gamma')
 
-    return tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name='output')
+    ret = tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name='output')
+
+    vh = ret.variables = VariableHolder()
+    if use_scale:
+        vh.gamma = gamma
+    if use_bias:
+        vh.beta = beta
+    return ret
 
 
 @layer_register()
```
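With the holder attached, callers can reach LayerNorm's affine parameters through the returned tensor. A hypothetical usage sketch, assuming `use_scale` and `use_bias` are both enabled:

```python
# Hypothetical usage; 'ln' is just an example layer name and x an input
# tensor built elsewhere.
out = LayerNorm('ln', x)
gamma = out.variables.gamma  # scale, attached only when use_scale is True
beta = out.variables.beta    # shift, attached only when use_bias is True
```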
```diff
@@ -90,4 +97,10 @@ def InstanceNorm(x, epsilon=1e-5, use_affine=True, gamma_init=None, data_format=
             gamma_init = tf.constant_initializer(1.0)
         gamma = tf.get_variable('gamma', [ch], initializer=gamma_init)
         gamma = tf.reshape(gamma, new_shape)
-    return tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name='output')
+    ret = tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name='output')
+
+    vh = ret.variables = VariableHolder()
+    if use_affine:
+        vh.gamma = gamma
+        vh.beta = beta
+    return ret
```
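InstanceNorm gets the same treatment, except a single `use_affine` flag gates both parameters, so `vh.gamma` and `vh.beta` are attached together or not at all. Assigning the holder via `ret.variables = VariableHolder()` keeps the layer's return value a plain tensor while still exposing its variables, the same pattern the shared `VariableHolder` from `.common` suggests other tensorpack layers follow.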