Commit 196a17f3, authored Mar 06, 2018 by Yuxin Wu
Fix some more scope issues
parent 2770ede8
Showing 3 changed files with 13 additions and 5 deletions:

examples/FasterRCNN/basemodel.py       +2  -2
examples/FasterRCNN/model.py           +2  -1
tensorpack/tfutils/scope_utils.py      +9  -2
examples/FasterRCNN/basemodel.py  (view file @ 196a17f3)

@@ -6,6 +6,7 @@ from contextlib import contextmanager
 import tensorflow as tf
 from tensorpack.tfutils.argscope import argscope, get_arg_scope
 from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
+from tensorpack.tfutils.varreplace import custom_getter_scope
 from tensorpack.models import (
     Conv2D, MaxPooling, BatchNorm, BNReLU)
@@ -26,8 +27,7 @@ def resnet_argscope():
     with argscope([Conv2D, MaxPooling, BatchNorm], data_format='NCHW'), \
             argscope(Conv2D, use_bias=False), \
             argscope(BatchNorm, use_local_stat=False), \
-            tf.variable_scope(tf.get_variable_scope(),
-                              custom_getter=maybe_freeze_affine):
+            custom_getter_scope(maybe_freeze_affine):
         yield
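For context on the basemodel.py change: the commit replaces a hand-rolled tf.variable_scope(tf.get_variable_scope(), custom_getter=...) with tensorpack's custom_getter_scope helper. Below is a minimal, illustrative sketch of such a helper and a getter, assuming TF 1.x graph mode; the helper name matches tensorpack/tfutils/varreplace.py, but the body and the maybe_freeze_affine getter shown here are reconstructions for illustration, not the library or repository source.

from contextlib import contextmanager
import tensorflow as tf

@contextmanager
def custom_getter_scope(custom_getter):
    # Re-enter the *current* variable scope with only a custom getter swapped
    # in: variable names stay the same, but every tf.get_variable() call made
    # inside the block is routed through `custom_getter`.
    scope = tf.get_variable_scope()
    with tf.variable_scope(scope, custom_getter=custom_getter):
        yield

def maybe_freeze_affine(getter, *args, **kwargs):
    # Illustrative getter: stop gradients through BatchNorm's affine parameters.
    v = getter(*args, **kwargs)
    if v.op.name.endswith('/gamma') or v.op.name.endswith('/beta'):
        v = tf.stop_gradient(v)
    return v

Using the helper instead of re-entering the variable scope by hand also lets one place deal with the name-scope side effects that this commit fixes in scope_utils.py further down.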
examples/FasterRCNN/model.py  (view file @ 196a17f3)

@@ -503,7 +503,8 @@ def maskrcnn_head(feature, num_class):
     """
     with argscope([Conv2D, Deconv2D], data_format='NCHW',
                   W_init=tf.variance_scaling_initializer(
-                      scale=2.0, mode='fan_in', distribution='normal')):
+                      scale=2.0, mode='fan_out',
+                      distribution='normal')):  # c2's MSRAFill is fan_out
         l = Deconv2D('deconv', feature, 256, 2, stride=2, nl=tf.nn.relu)
         l = Conv2D('conv', l, num_class - 1, 1)
         return l
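The model.py change only switches the He/MSRA initializer from fan_in to fan_out, matching Caffe2's MSRAFill as the inline comment notes. The sketch below shows what that switch does to the initializer's standard deviation; the kernel shape and class count are hypothetical examples, and stddev ≈ sqrt(scale / fan) is the usual variance-scaling rule, not code from the repository.

import numpy as np

# Hypothetical 1x1 conv like the mask head's final layer: 256 channels in,
# (num_class - 1) channels out, e.g. 80 for a COCO-style 81-class setup.
kh, kw, c_in, c_out = 1, 1, 256, 80
scale = 2.0

fan_in = kh * kw * c_in      # used when mode='fan_in'
fan_out = kh * kw * c_out    # used when mode='fan_out' (MSRAFill-style)

print(np.sqrt(scale / fan_in))   # ~0.088
print(np.sqrt(scale / fan_out))  # ~0.158 -> noticeably larger initial weights

Whenever c_in != c_out the two modes give different scales, so aligning the mode with the reference Caffe2 implementation changes the effective initialization of the mask head.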
tensorpack/tfutils/scope_utils.py  (view file @ 196a17f3)

@@ -8,6 +8,7 @@ import functools
 from contextlib import contextmanager
 from ..utils.argtools import graph_memoized
+from .common import get_tf_version_number
 __all__ = ['auto_reuse_variable_scope', 'cached_name_scope', 'under_name_scope']
@@ -39,7 +40,13 @@ def auto_reuse_variable_scope(func):
         h = hash((tf.get_default_graph(), scope.name))
         # print("Entering " + scope.name + " reuse: " + str(h in used_scope))
         if h in used_scope:
-            with tf.variable_scope(scope, reuse=True):
-                return func(*args, **kwargs)
+            if get_tf_version_number() >= 1.5:
+                with tf.variable_scope(scope, reuse=True, auxiliary_name_scope=False):
+                    return func(*args, **kwargs)
+            else:
+                ns = tf.get_default_graph().get_name_scope()
+                with tf.variable_scope(scope, reuse=True), \
+                        tf.name_scope(ns + '/' if ns else ''):
+                    return func(*args, **kwargs)
         else:
             used_scope.add(h)
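The scope_utils.py hunk is the core of the fix: when auto_reuse_variable_scope re-enters a previously used variable scope with reuse=True, it now tells TF >= 1.5 not to open an auxiliary name scope (auxiliary_name_scope=False), and on older versions it re-pins the currently active name scope by hand. Below is a hedged usage sketch, assuming TF 1.x graph mode; the function and tower names are made up for illustration and are not from the commit.

import tensorflow as tf
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope

@auto_reuse_variable_scope
def shared_fc(x):
    # Called under the same variable scope twice, the second call reuses "W"
    # instead of creating "W_1" or raising a ValueError.
    W = tf.get_variable('W', [int(x.shape[-1]), 10])
    return tf.matmul(x, W, name='logits')

x = tf.placeholder(tf.float32, [None, 128])
with tf.name_scope('tower0'):
    y0 = shared_fc(x)   # creates variable "W", op named "tower0/logits"
with tf.name_scope('tower1'):
    y1 = shared_fc(x)   # reuses "W"; with this fix the op stays under "tower1/"
                        # rather than escaping to the re-entered scope's own name scope

This is the same pattern the basemodel.py change relies on: share variables across calls while keeping op names under whatever name scope the caller is currently in.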