Shashank Suhas / seminar-breakout · Commits

Commit e66857ba, authored Jan 22, 2017 by Yuxin Wu
Parent: e1278514

use reuse context instead of reuse_variables()

Showing 3 changed files with 11 additions and 10 deletions (+11 / -10):
tensorpack/callbacks/inference_runner.py   +3 / -4
tensorpack/predict/base.py                 +4 / -3
tensorpack/predict/multigpu.py             +4 / -3
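
For reference, the two idioms the commit message contrasts behave differently in the TF 1.x API: calling reuse_variables() flips the current variable scope into reuse mode with no way to switch it back inside that scope, while entering tf.variable_scope(tf.get_variable_scope(), reuse=True) limits reuse to the with block. A minimal sketch (TF 1.x; the scope name "net" and variable "w" are made up for illustration):

    import tensorflow as tf  # TF 1.x API

    with tf.variable_scope("net"):
        a = tf.get_variable("w", shape=[4])       # first call creates net/w

    # Old style: reuse_variables() switches the scope into reuse mode.
    with tf.variable_scope("net") as scope:
        scope.reuse_variables()
        b = tf.get_variable("w", shape=[4])       # returns the existing net/w
        # for the rest of this block, every get_variable() must find an
        # existing variable; new ones can no longer be created here

    # New style: reuse is confined to the context manager.
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        with tf.variable_scope("net"):
            c = tf.get_variable("w", shape=[4])   # also returns the existing net/w

    assert a is b and b is c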
tensorpack/callbacks/inference_runner.py (view file @ e66857ba)

@@ -182,10 +182,9 @@ class FeedfreeInferenceRunner(Callback):
     def _setup_graph(self):
         self._find_input_tensors()  # tensors
-        tf.get_variable_scope().reuse_variables()
-        # overwrite the FeedfreeInferenceRunner scope
-        with tf.name_scope(None), \
+        # overwrite the FeedfreeInferenceRunner name scope
+        with tf.variable_scope(tf.get_variable_scope(), reuse=True), \
+                tf.name_scope(None), \
                 freeze_collection(SUMMARY_BACKUP_KEYS):
             def fn(_):
                 self.trainer.model.build_graph(self._input_tensors)
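
In the hunk above, the reuse context is combined with tf.name_scope(None), which in TF 1.x resets the name scope to the graph root, so the inference copy of the graph shares the training variables without nesting its op names under the runner's own scope. A rough sketch of that idea, not tensorpack's actual API (build_model, the "fc" scope, and the "InferenceRunner" name scope are hypothetical):

    import tensorflow as tf  # TF 1.x API

    def build_model(x):
        # hypothetical stand-in for the trainer's model.build_graph()
        with tf.variable_scope("fc"):
            w = tf.get_variable("w", shape=[8, 2])
        return tf.matmul(x, w)

    x_train = tf.placeholder(tf.float32, [None, 8], name="x_train")
    train_out = build_model(x_train)              # first build creates fc/w

    # Rebuild the model from inside some callback's name scope: share the
    # variables, and reset the name scope so the new ops sit at the top level.
    with tf.name_scope("InferenceRunner"):
        with tf.variable_scope(tf.get_variable_scope(), reuse=True), \
                tf.name_scope(None):
            x_test = tf.placeholder(tf.float32, [None, 8], name="x_test")
            test_out = build_model(x_test)        # second build reuses fc/w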
tensorpack/predict/base.py (view file @ e66857ba)

@@ -152,11 +152,12 @@ def build_prediction_graph(build_tower_fn, towers=[0], prefix=''):
         prefix: an extra prefix in tower name. The final tower prefix will be
             determined by :meth:`TowerContext.get_predict_tower_name`.
     """
-    for k in towers:
+    for idx, k in enumerate(towers):
         logger.info(
             "Building prediction graph for towerid={} with prefix='{}' ...".format(k, prefix))
         towername = TowerContext.get_predict_tower_name(prefix, k)
         with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
-                TowerContext(towername, is_training=False):
+                TowerContext(towername, is_training=False), \
+                tf.variable_scope(tf.get_variable_scope(),
+                                  reuse=True if idx > 0 else None):
             build_tower_fn(k)
-            tf.get_variable_scope().reuse_variables()
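
Both predictor files switch to the same pattern: the first tower (idx == 0) creates the variables and every later tower re-enters the current variable scope with reuse=True, so all towers share one set of weights. A hedged sketch of that loop under assumed names (build_tower, the "linear" scope, and the placeholder names are made up):

    import tensorflow as tf  # TF 1.x API

    def build_tower(x):
        # hypothetical per-tower graph builder
        with tf.variable_scope("linear"):
            w = tf.get_variable("w", shape=[8, 2])
        return tf.matmul(x, w)

    towers = [0, 1]        # GPU ids; -1 would mean CPU, as in the diffs above
    outputs = []
    for idx, k in enumerate(towers):
        device = '/gpu:{}'.format(k) if k >= 0 else '/cpu:0'
        with tf.device(device), \
                tf.variable_scope(tf.get_variable_scope(),
                                  reuse=True if idx > 0 else None):
            x = tf.placeholder(tf.float32, [None, 8], name='input{}'.format(k))
            # tower 0 creates linear/w; every later tower reuses it
            outputs.append(build_tower(x))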
tensorpack/predict/multigpu.py (view file @ e66857ba)

@@ -71,16 +71,17 @@ class DataParallelOfflinePredictor(OnlinePredictor):
         sess = tf.Session(config=config.session_config)
         input_var_names = []
         output_vars = []
-        for k in towers:
+        for idx, k in enumerate(towers):
             towername = PREDICT_TOWER + str(k)
             input_vars = config.model.build_placeholders(prefix=towername + '-')
             logger.info("Building graph for predictor tower {}...".format(k))
             with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
-                    TowerContext(towername, is_training=False):
+                    TowerContext(towername, is_training=False), \
+                    tf.variable_scope(tf.get_variable_scope(),
+                                      reuse=True if idx > 0 else None):
                 config.model.build_graph(input_vars)
-                tf.get_variable_scope().reuse_variables()
                 input_var_names.extend([k.name for k in input_vars])
                 output_vars.extend(get_tensors_by_names([towername + '/' + n
...