Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
2770ede8
Commit
2770ede8
authored
Mar 05, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fix name scope after using tflayers (#627)
parent
4adbaa94
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
16 additions
and
19 deletions
+16
-19
tensorpack/graph_builder/utils.py
tensorpack/graph_builder/utils.py
+3
-17
tensorpack/tfutils/varreplace.py
tensorpack/tfutils/varreplace.py
+13
-2
No files found.
tensorpack/graph_builder/utils.py
View file @
2770ede8
...
...
@@ -7,7 +7,7 @@ from contextlib import contextmanager
import
operator
import
tensorflow
as
tf
from
..tfutils.
common
import
get_tf_version_number
from
..tfutils.
varreplace
import
custom_getter_scope
__all__
=
[
'LeastLoadedDeviceSetter'
,
...
...
@@ -42,22 +42,8 @@ def override_to_local_variable(enable=True):
_replace_global_by_local
(
kwargs
)
return
getter
(
name
,
*
args
,
**
kwargs
)
orig_vs
=
tf
.
get_variable_scope
()
if
get_tf_version_number
()
>=
1.5
:
with
tf
.
variable_scope
(
orig_vs
,
custom_getter
=
custom_getter
,
auxiliary_name_scope
=
False
):
yield
else
:
if
get_tf_version_number
()
>=
1.2
:
ns
=
tf
.
get_default_graph
()
.
get_name_scope
()
else
:
ns
=
orig_vs
.
original_name_scope
with
tf
.
variable_scope
(
orig_vs
,
custom_getter
=
custom_getter
):
with
tf
.
name_scope
(
ns
+
'/'
if
ns
else
''
):
yield
with
custom_getter_scope
(
custom_getter
):
yield
else
:
yield
...
...
tensorpack/tfutils/varreplace.py
View file @
2770ede8
...
...
@@ -6,14 +6,25 @@
import
tensorflow
as
tf
from
contextlib
import
contextmanager
from
.common
import
get_tf_version_number
__all__
=
[
'freeze_variables'
,
'remap_variables'
]
@
contextmanager
def
custom_getter_scope
(
custom_getter
):
scope
=
tf
.
get_variable_scope
()
with
tf
.
variable_scope
(
scope
,
custom_getter
=
custom_getter
):
yield
if
get_tf_version_number
()
>=
1.5
:
with
tf
.
variable_scope
(
scope
,
custom_getter
=
custom_getter
,
auxiliary_name_scope
=
False
):
yield
else
:
ns
=
tf
.
get_default_graph
()
.
get_name_scope
()
with
tf
.
variable_scope
(
scope
,
custom_getter
=
custom_getter
):
with
tf
.
name_scope
(
ns
+
'/'
if
ns
else
''
):
yield
def
remap_variables
(
fn
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment