Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
7699fd9b
Commit
7699fd9b
authored
Jul 14, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
let multi-gpu OfflinePredictor use PredictorFactory
parent
e1f9cc09
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
54 additions
and
33 deletions
+54
-33
tensorpack/graph_builder/input_source.py
tensorpack/graph_builder/input_source.py
+23
-5
tensorpack/graph_builder/input_source_base.py
tensorpack/graph_builder/input_source_base.py
+0
-1
tensorpack/graph_builder/predictor_factory.py
tensorpack/graph_builder/predictor_factory.py
+2
-2
tensorpack/predict/multigpu.py
tensorpack/predict/multigpu.py
+29
-25
No files found.
tensorpack/graph_builder/input_source.py
View file @
7699fd9b
...
...
@@ -21,13 +21,31 @@ from ..utils import logger
from
..utils.concurrency
import
ShareSessionThread
from
..callbacks.base
import
Callback
__all__
=
[
'FeedInput'
,
'DataParallelFeedInput'
,
__all__
=
[
'
PlaceholderInput'
,
'
FeedInput'
,
'DataParallelFeedInput'
,
'FeedfreeInput'
,
'QueueInput'
,
'BatchQueueInput'
,
'ZMQInput'
,
'DummyConstantInput'
,
'TensorInput'
,
'StagingInputWrapper'
]
class
PlaceholderInput
(
InputSource
):
"""
Just produce placeholders as input tensors.
"""
def
__init__
(
self
,
prefix
=
''
):
"""
Args:
prefix(str): an optional prefix to add to the placeholder.
"""
self
.
_prefix
=
prefix
def
_setup
(
self
,
inputs
):
self
.
_all_placehdrs
=
[
v
.
build_placeholder
(
prefix
=
self
.
_prefix
)
for
v
in
inputs
]
def
_get_input_tensors
(
self
):
return
self
.
_all_placehdrs
class
FeedInput
(
InputSource
):
""" Input by iterating over a DataFlow and feed datapoints. """
...
...
@@ -60,16 +78,16 @@ class FeedInput(InputSource):
return
self
.
ds
.
size
()
def
_setup
(
self
,
inputs
):
self
.
_all_placehdrs
=
[
v
.
build_placeholder
_reuse
(
)
for
v
in
inputs
]
self
.
_all_placehdrs
=
[
v
.
build_placeholder
(
prefix
=
self
.
_prefix
)
for
v
in
inputs
]
self
.
_cb
=
self
.
_FeedCallback
(
self
.
_repeat_ds
,
self
.
_all_placehdrs
)
self
.
reset_state
()
def
_reset_state
(
self
):
self
.
_cb
.
_reset
()
def
_get_input_tensors
(
self
):
return
self
.
_all_placehdrs
def
_reset_state
(
self
):
self
.
_cb
.
_reset
()
def
_get_callbacks
(
self
):
return
[
self
.
_cb
]
...
...
tensorpack/graph_builder/input_source_base.py
View file @
7699fd9b
...
...
@@ -54,7 +54,6 @@ class InputSource(object):
# TODO
self
.
_reset_state
()
@
abstractmethod
def
_reset_state
(
self
):
pass
...
...
tensorpack/graph_builder/predictor_factory.py
View file @
7699fd9b
...
...
@@ -7,7 +7,6 @@ from ..utils import logger
from
..tfutils.common
import
get_op_tensor_name
,
get_tensors_by_names
from
..tfutils.tower
import
TowerContext
from
..tfutils.collection
import
freeze_collection
from
..predict
import
OnlinePredictor
from
..utils.naming
import
TOWER_FREEZE_KEYS
__all__
=
[
'PredictorFactory'
]
...
...
@@ -41,7 +40,7 @@ class PredictorTowerHandle(object):
class
PredictorFactory
(
object
):
""" Make predictors from :class:`ModelDesc`."""
def
__init__
(
self
,
model
,
towers
,
vs_name
):
def
__init__
(
self
,
model
,
towers
,
vs_name
=
''
):
"""
Args:
model (ModelDesc):
...
...
@@ -97,4 +96,5 @@ class PredictorFactory(object):
in_tensors
=
handle
.
get_tensors
(
input_names
)
out_tensors
=
handle
.
get_tensors
(
output_names
)
from
..predict
import
OnlinePredictor
# noqa TODO
return
OnlinePredictor
(
in_tensors
,
out_tensors
)
tensorpack/predict/multigpu.py
View file @
7699fd9b
...
...
@@ -3,9 +3,12 @@
# File: multigpu.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
tensorflow
as
tf
from
..utils
import
logger
from
..tfutils
import
get_tensors_by_names
,
TowerContext
from
.base
import
OnlinePredictor
,
build_prediction_graph
,
PredictorTowerBuilder
from
..tfutils
import
TowerContext
from
..graph_builder.predictor_factory
import
PredictorFactory
from
..graph_builder.input_source
import
PlaceholderInput
from
.base
import
OnlinePredictor
__all__
=
[
'MultiTowerOfflinePredictor'
,
'DataParallelOfflinePredictor'
]
...
...
@@ -23,20 +26,24 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
assert
len
(
towers
)
>
0
self
.
graph
=
config
.
_maybe_create_graph
()
self
.
predictors
=
[]
self
.
return_input
=
config
.
return_input
with
self
.
graph
.
as_default
():
placeholder_names
=
set
([
k
.
name
for
k
in
config
.
model
.
get_inputs_desc
()])
handles
=
[]
factory
=
PredictorFactory
(
config
.
model
,
towers
)
for
idx
,
t
in
enumerate
(
towers
):
tower_name
=
TowerContext
.
get_predict_tower_name
(
t
)
device
=
'/gpu:'
+
str
(
t
)
def
fn
(
_
):
config
.
model
.
build_graph
(
config
.
model
.
get_reused_placehdrs
())
build_prediction_graph
(
fn
,
towers
)
# TODO smarter TowerContext?
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
idx
>
0
):
handles
.
append
(
factory
.
build
(
tower_name
,
device
)
)
self
.
sess
=
config
.
session_creator
.
create_session
()
config
.
session_init
.
init
(
self
.
sess
)
get_tensor_fn
=
PredictorTowerBuilder
.
get_tensors_maybe_in_tower
for
k
in
towers
:
input_tensors
=
get_tensor_fn
(
placeholder_names
,
config
.
input_names
,
k
)
output_tensors
=
get_tensor_fn
(
placeholder_names
,
config
.
output_names
,
k
)
for
h
in
handles
:
input_tensors
=
h
.
get_tensors
(
config
.
input_names
)
output_tensors
=
h
.
get_tensors
(
config
.
output_names
)
self
.
predictors
.
append
(
OnlinePredictor
(
input_tensors
,
output_tensors
,
config
.
return_input
,
self
.
sess
))
...
...
@@ -79,23 +86,20 @@ class DataParallelOfflinePredictor(OnlinePredictor):
"""
self
.
graph
=
config
.
_maybe_create_graph
()
with
self
.
graph
.
as_default
():
input_
name
s
=
[]
input_
tensor
s
=
[]
output_tensors
=
[]
def
build_tower
(
k
):
towername
=
TowerContext
.
get_predict_tower_name
(
k
)
# inputs (placeholders) for this tower only
input_tensors
=
config
.
model
.
build_placeholders
(
prefix
=
towername
+
'/'
)
config
.
model
.
build_graph
(
input_tensors
)
input_names
.
extend
([
t
.
name
for
t
in
input_tensors
])
output_tensors
.
extend
(
get_tensors_by_names
(
[
towername
+
'/'
+
n
for
n
in
config
.
output_names
]))
build_prediction_graph
(
build_tower
,
towers
)
input_tensors
=
get_tensors_by_names
(
input_names
)
factory
=
PredictorFactory
(
config
.
model
,
towers
)
for
idx
,
t
in
enumerate
(
towers
):
tower_name
=
TowerContext
.
get_predict_tower_name
(
t
)
device
=
'/gpu:'
+
str
(
t
)
input
=
PlaceholderInput
(
tower_name
+
'/'
)
input
.
setup
(
config
.
model
.
get_inputs_desc
())
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
idx
>
0
):
h
=
factory
.
build
(
tower_name
,
device
,
)
input_tensors
.
extend
(
h
.
get_tensors
(
config
.
input_names
))
output_tensors
.
extend
(
h
.
get_tensors
(
config
.
output_names
))
sess
=
config
.
session_creator
.
create_session
()
config
.
session_init
.
init
(
sess
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment