Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
e1f9cc09
Commit
e1f9cc09
authored
Jul 14, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
let DataParallelInference use PredictorFactory
parent
3951aaf7
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
27 deletions
+17
-27
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+17
-27
No files found.
tensorpack/callbacks/inference_runner.py
View file @
e1f9cc09
...
@@ -7,6 +7,7 @@ import tensorflow as tf
...
@@ -7,6 +7,7 @@ import tensorflow as tf
from
tensorflow.python.training.monitored_session
\
from
tensorflow.python.training.monitored_session
\
import
_HookedSession
as
HookedSession
import
_HookedSession
as
HookedSession
import
itertools
from
abc
import
ABCMeta
,
abstractmethod
from
abc
import
ABCMeta
,
abstractmethod
import
tqdm
import
tqdm
import
six
import
six
...
@@ -15,13 +16,11 @@ from six.moves import range
...
@@ -15,13 +16,11 @@ from six.moves import range
from
..utils
import
logger
,
get_tqdm_kwargs
from
..utils
import
logger
,
get_tqdm_kwargs
from
..utils.develop
import
deprecated
from
..utils.develop
import
deprecated
from
..dataflow
import
DataFlow
from
..dataflow
import
DataFlow
from
..tfutils.common
import
get_tensors_by_names
from
..tfutils.tower
import
TowerContext
from
..tfutils.tower
import
TowerContext
from
..graph_builder.input_source_base
import
InputSource
from
..graph_builder.input_source_base
import
InputSource
from
..graph_builder.input_source
import
(
from
..graph_builder.input_source
import
(
FeedInput
,
DataParallelFeedInput
,
FeedfreeInput
,
TensorInput
)
FeedInput
,
DataParallelFeedInput
,
FeedfreeInput
,
TensorInput
)
from
..predict
import
PredictorTowerBuilder
from
.base
import
Callback
from
.base
import
Callback
from
.inference
import
Inferencer
from
.inference
import
Inferencer
...
@@ -88,11 +87,12 @@ class InferenceRunnerBase(Callback):
...
@@ -88,11 +87,12 @@ class InferenceRunnerBase(Callback):
self
.
_extra_hooks
=
extra_hooks
self
.
_extra_hooks
=
extra_hooks
def
_setup_graph
(
self
):
def
_setup_graph
(
self
):
self
.
_input_source
.
setup
(
self
.
trainer
.
model
.
get_inputs_desc
())
# Use predict_tower in train config. either gpuid or -1
# Use predict_tower in train config. either gpuid or -1
tower_id
=
self
.
trainer
.
config
.
predict_tower
[
0
]
tower_id
=
self
.
trainer
.
config
.
predict_tower
[
0
]
device
=
'/gpu:{}'
.
format
(
tower_id
)
if
tower_id
>=
0
else
'/cpu:0'
device
=
'/gpu:{}'
.
format
(
tower_id
)
if
tower_id
>=
0
else
'/cpu:0'
tower_name
=
TowerContext
.
get_predict_tower_name
(
tower_id
,
prefix
=
self
.
_prefix
)
tower_name
=
TowerContext
.
get_predict_tower_name
(
tower_id
,
prefix
=
self
.
_prefix
)
self
.
_input_source
.
setup
(
self
.
trainer
.
model
.
get_inputs_desc
())
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
):
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
):
self
.
_tower_handle
=
self
.
trainer
.
predictor_factory
.
build
(
tower_name
,
device
,
self
.
_input_source
)
self
.
_tower_handle
=
self
.
trainer
.
predictor_factory
.
build
(
tower_name
,
device
,
self
.
_input_source
)
...
@@ -172,18 +172,15 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
...
@@ -172,18 +172,15 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
self
.
_gpus
=
gpus
self
.
_gpus
=
gpus
def
_setup_graph
(
self
):
def
_setup_graph
(
self
):
model
=
self
.
trainer
.
model
self
.
_input_source
.
setup
(
self
.
trainer
.
model
.
get_inputs_desc
())
self
.
_input_source
.
setup
(
model
.
get_inputs_desc
())
self
.
_handles
=
[]
# build graph
def
build_tower
(
k
):
# inputs (placeholders) for this tower only
model
.
build_graph
(
self
.
_input_source
)
builder
=
PredictorTowerBuilder
(
build_tower
,
prefix
=
self
.
_prefix
)
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
):
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
):
for
t
in
self
.
_gpus
:
for
t
in
self
.
_gpus
:
builder
.
build
(
t
)
tower_name
=
TowerContext
.
get_predict_tower_name
(
t
,
prefix
=
self
.
_prefix
)
device
=
'/gpu:{}'
.
format
(
t
)
self
.
_handles
.
append
(
self
.
trainer
.
predictor_factory
.
build
(
tower_name
,
device
,
self
.
_input_source
))
# setup feeds and hooks
# setup feeds and hooks
self
.
_hooks_parallel
=
[
self
.
_build_hook_parallel
(
inf
)
for
inf
in
self
.
infs
]
self
.
_hooks_parallel
=
[
self
.
_build_hook_parallel
(
inf
)
for
inf
in
self
.
infs
]
...
@@ -191,15 +188,12 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
...
@@ -191,15 +188,12 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
cbs
=
self
.
_input_source
.
get_callbacks
()
cbs
=
self
.
_input_source
.
get_callbacks
()
self
.
_hooks_parallel
.
extend
([
CallbackToHook
(
cb
)
for
cb
in
cbs
])
self
.
_hooks_parallel
.
extend
([
CallbackToHook
(
cb
)
for
cb
in
cbs
])
def
_duplicate_names_across_towers
(
self
,
names
):
ret
=
[]
for
t
in
self
.
_gpus
:
ret
.
extend
([
TowerContext
.
get_predict_tower_name
(
t
,
self
.
_prefix
)
+
'/'
+
n
for
n
in
names
])
return
ret
class
InferencerToHookDataParallel
(
InferencerToHook
):
class
InferencerToHookDataParallel
(
InferencerToHook
):
def
__init__
(
self
,
inf
,
fetches
,
size
):
def
__init__
(
self
,
inf
,
fetches
,
size
):
"""
Args:
size(int): number of tensors to fetch per tower
"""
super
(
DataParallelInferenceRunner
.
InferencerToHookDataParallel
,
self
)
.
__init__
(
inf
,
fetches
)
super
(
DataParallelInferenceRunner
.
InferencerToHookDataParallel
,
self
)
.
__init__
(
inf
,
fetches
)
assert
len
(
self
.
_fetches
)
%
size
==
0
assert
len
(
self
.
_fetches
)
%
size
==
0
self
.
_sz
=
size
self
.
_sz
=
size
...
@@ -213,16 +207,12 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
...
@@ -213,16 +207,12 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
def
_build_hook_parallel
(
self
,
inf
):
def
_build_hook_parallel
(
self
,
inf
):
out_names
=
inf
.
get_output_tensors
()
out_names
=
inf
.
get_output_tensors
()
sz
=
len
(
out_names
)
sz
=
len
(
out_names
)
out_names
=
self
.
_duplicate_names_across_towers
(
out_names
)
fetches
=
list
(
itertools
.
chain
(
*
[
t
.
get_tensors
(
out_names
)
for
t
in
self
.
_handles
]))
fetches
=
get_tensors_by_names
(
out_names
)
return
self
.
InferencerToHookDataParallel
(
inf
,
fetches
,
sz
)
return
DataParallelInferenceRunner
.
InferencerToHookDataParallel
(
inf
,
fetches
,
sz
)
def
_build_hook
(
self
,
inf
):
def
_build_hook
(
self
,
inf
):
out_names
=
inf
.
get_output_tensors
()
out_names
=
inf
.
get_output_tensors
()
names
=
[
TowerContext
.
get_predict_tower_name
(
fetches
=
self
.
_handles
[
0
]
.
get_tensors
(
out_names
)
self
.
_gpus
[
0
],
self
.
_prefix
)
+
'/'
+
n
for
n
in
out_names
]
fetches
=
get_tensors_by_names
(
names
)
return
InferencerToHook
(
inf
,
fetches
)
return
InferencerToHook
(
inf
,
fetches
)
def
_before_train
(
self
):
def
_before_train
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment