Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
fb501e66
Commit
fb501e66
authored
Jul 14, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
InferenceRunnerBase uses PredictorFactory to build graph
parent
4ee1e735
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
19 additions
and
21 deletions
+19
-21
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+7
-13
tensorpack/graph_builder/predictor_factory.py
tensorpack/graph_builder/predictor_factory.py
+1
-1
tensorpack/predict/base.py
tensorpack/predict/base.py
+1
-2
tensorpack/train/base.py
tensorpack/train/base.py
+10
-5
No files found.
tensorpack/callbacks/inference_runner.py
View file @
fb501e66
...
...
@@ -87,12 +87,11 @@ class InferenceRunnerBase(Callback):
def
_setup_graph
(
self
):
self
.
_input_source
.
setup
(
self
.
trainer
.
model
.
get_inputs_desc
())
# Use predict_tower in train config. either gpuid or -1
self
.
_predict_tower_id
=
self
.
trainer
.
config
.
predict_tower
[
0
]
def
fn
(
_
):
self
.
trainer
.
model
.
build_graph
(
self
.
_input_source
)
with
tf
.
variable_scope
(
self
.
trainer
.
vs_name_for_predictor
,
reuse
=
True
):
PredictorTowerBuilder
(
fn
,
self
.
_prefix
)
.
build
(
self
.
_predict_tower_id
)
tower_id
=
self
.
trainer
.
config
.
predict_tower
[
0
]
device
=
'/gpu:{}'
.
format
(
tower_id
)
if
tower_id
>=
0
else
'/cpu:0'
tower_name
=
TowerContext
.
get_predict_tower_name
(
tower_id
,
prefix
=
self
.
_prefix
)
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
):
self
.
_tower_handle
=
self
.
trainer
.
predictor_factory
.
build
(
tower_name
,
device
,
self
.
_input_source
)
self
.
_hooks
=
[
self
.
_build_hook
(
inf
)
for
inf
in
self
.
infs
]
cbs
=
self
.
_input_source
.
get_callbacks
()
...
...
@@ -102,11 +101,6 @@ class InferenceRunnerBase(Callback):
self
.
_hooks
.
extend
(
self
.
_extra_hooks
)
self
.
_hooked_sess
=
HookedSession
(
self
.
trainer
.
sess
,
self
.
_hooks
)
def
_get_tensors_maybe_in_tower
(
self
,
names
):
placeholder_names
=
set
([
k
.
name
for
k
in
self
.
trainer
.
model
.
get_inputs_desc
()])
get_tensor_fn
=
PredictorTowerBuilder
.
get_tensors_maybe_in_tower
return
get_tensor_fn
(
placeholder_names
,
names
,
self
.
_predict_tower_id
,
prefix
=
self
.
_prefix
)
@
abstractmethod
def
_build_hook
(
self
,
inf
):
pass
...
...
@@ -142,7 +136,7 @@ class InferenceRunner(InferenceRunnerBase):
def
_build_hook
(
self
,
inf
):
out_names
=
inf
.
get_output_tensors
()
fetches
=
self
.
_
get_tensors_maybe_in_tower
(
out_names
)
fetches
=
self
.
_
tower_handle
.
get_tensors
(
out_names
)
return
InferencerToHook
(
inf
,
fetches
)
...
...
@@ -170,7 +164,7 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
ret
=
[]
for
name
in
out_names
:
assert
name
not
in
placeholder_names
,
"Currently inferencer don't support fetching placeholders!"
ret
.
append
(
self
.
_
get_tensors_maybe_in_tower
([
name
])[
0
])
ret
.
append
(
self
.
_
tower_handle
.
get_tensors
([
name
])[
0
])
return
InferencerToHook
(
inf
,
ret
)
...
...
tensorpack/graph_builder/predictor_factory.py
View file @
fb501e66
...
...
@@ -33,7 +33,7 @@ class PredictorTowerHandle(object):
class
PredictorFactory
(
object
):
""" Make predictors from :class:`ModelDesc`
and cache them
."""
""" Make predictors from :class:`ModelDesc`."""
def
__init__
(
self
,
model
,
towers
,
vs_name
):
"""
...
...
tensorpack/predict/base.py
View file @
fb501e66
...
...
@@ -167,7 +167,6 @@ class PredictorTowerBuilder(object):
tower (int): the tower will be built on device '/gpu:{tower}', or
'/cpu:0' if tower is -1.
"""
toweridx
=
max
(
tower
,
0
)
# if CPU, named the tower as 0
towername
=
TowerContext
.
get_predict_tower_name
(
tower
,
self
.
_prefix
)
if
self
.
_prefix
:
msg
=
"Building predictor graph {} on gpu={} with prefix='{}' ..."
.
format
(
...
...
@@ -180,7 +179,7 @@ class PredictorTowerBuilder(object):
with
tf
.
name_scope
(
None
),
\
freeze_collection
(
TOWER_FREEZE_KEYS
),
\
tf
.
device
(
device
),
\
TowerContext
(
towername
,
is_training
=
False
,
index
=
toweridx
):
TowerContext
(
towername
,
is_training
=
False
):
self
.
_fn
(
tower
)
# useful only when the placeholders don't have tower prefix
...
...
tensorpack/train/base.py
View file @
fb501e66
...
...
@@ -211,7 +211,7 @@ class Trainer(object):
self
.
_callbacks
.
after_train
()
self
.
hooked_sess
.
close
()
# Predictor related methods:
TODO
# Predictor related methods:
@
property
def
vs_name_for_predictor
(
self
):
"""
...
...
@@ -229,16 +229,21 @@ class Trainer(object):
Returns:
an :class:`OnlinePredictor`.
"""
if
not
hasattr
(
self
,
'_predictor_factory'
):
self
.
_predictor_factory
=
PredictorFactory
(
self
.
model
,
self
.
config
.
predict_tower
,
self
.
vs_name_for_predictor
)
# TODO move the logic to factory?
nr_tower
=
len
(
self
.
config
.
predict_tower
)
if
nr_tower
<
tower
:
logger
.
warn
(
"Requested the {}th predictor but only have {} predict towers! "
"Predictors will be assigned to GPUs in round-robin."
.
format
(
tower
,
nr_tower
))
tower
=
tower
%
nr_tower
return
self
.
_predictor_factory
.
get_predictor
(
input_names
,
output_names
,
tower
)
return
self
.
predictor_factory
.
get_predictor
(
input_names
,
output_names
,
tower
)
@
property
def
predictor_factory
(
self
):
if
not
hasattr
(
self
,
'_predictor_factory'
):
self
.
_predictor_factory
=
PredictorFactory
(
self
.
model
,
self
.
config
.
predict_tower
,
self
.
vs_name_for_predictor
)
return
self
.
_predictor_factory
def
get_predictors
(
self
,
input_names
,
output_names
,
n
):
""" Return n predictors. """
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment