Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
de1a5acd
Commit
de1a5acd
authored
Sep 21, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
tower specific noreuse scope
parent
faec6370
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
1 deletion
+20
-1
tensorpack/models/model_desc.py
tensorpack/models/model_desc.py
+19
-0
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+1
-1
No files found.
tensorpack/models/model_desc.py
View file @
de1a5acd
...
@@ -39,6 +39,25 @@ class TowerContext(object):
...
@@ -39,6 +39,25 @@ class TowerContext(object):
def
is_training
(
self
):
def
is_training
(
self
):
return
self
.
_is_training
return
self
.
_is_training
@
property
def
name
(
self
):
return
self
.
_name
def
get_variable_on_tower
(
self
,
*
args
,
**
kwargs
):
"""
Get a variable for this tower specifically, without reusing.
Tensorflow doesn't allow reuse=False scope under a
reuse=True scope. This method provides a work around.
See https://www.tensorflow.org/versions/master/how_tos/variable_scope/index.html#basics-of-tfvariable-scope
:param args, kwargs: same as tf.get_variable()
"""
with
tf
.
variable_scope
(
self
.
_name
)
as
scope
:
with
tf
.
variable_scope
(
scope
,
reuse
=
False
):
scope
=
tf
.
get_variable_scope
()
assert
scope
.
reuse
==
False
return
tf
.
get_variable
(
*
args
,
**
kwargs
)
def
find_tensor_in_main_tower
(
self
,
graph
,
name
):
def
find_tensor_in_main_tower
(
self
,
graph
,
name
):
if
self
.
is_main_tower
:
if
self
.
is_main_tower
:
return
graph
.
get_tensor_by_name
(
name
)
return
graph
.
get_tensor_by_name
(
name
)
...
...
tensorpack/train/trainer.py
View file @
de1a5acd
...
@@ -46,8 +46,8 @@ class PredictorFactory(object):
...
@@ -46,8 +46,8 @@ class PredictorFactory(object):
return
OnlinePredictor
(
self
.
sess
,
raw_input_vars
,
output_vars
)
return
OnlinePredictor
(
self
.
sess
,
raw_input_vars
,
output_vars
)
def
_build_predict_tower
(
self
):
def
_build_predict_tower
(
self
):
# build_predict_tower might get called anywhere, but 'towerp' should be the outermost name scope
tf
.
get_variable_scope
()
.
reuse_variables
()
tf
.
get_variable_scope
()
.
reuse_variables
()
# build_predict_tower might get called anywhere, but 'towerp' should be the outermost name scope
with
tf
.
name_scope
(
None
),
\
with
tf
.
name_scope
(
None
),
\
freeze_collection
(
SUMMARY_BACKUP_KEYS
):
freeze_collection
(
SUMMARY_BACKUP_KEYS
):
build_multi_tower_prediction_graph
(
build_multi_tower_prediction_graph
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment