Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
0fdef168
Commit
0fdef168
authored
Oct 30, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
remove PredictorFactory, and build offline predictor by tower_func
parent
c2c895ab
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
38 additions
and
89 deletions
+38
-89
tensorpack/graph_builder/predictor_factory.py
tensorpack/graph_builder/predictor_factory.py
+1
-71
tensorpack/predict/base.py
tensorpack/predict/base.py
+3
-3
tensorpack/predict/config.py
tensorpack/predict/config.py
+22
-7
tensorpack/predict/multigpu.py
tensorpack/predict/multigpu.py
+12
-8
No files found.
tensorpack/graph_builder/predictor_factory.py
View file @
0fdef168
...
@@ -6,8 +6,7 @@ import tensorflow as tf
...
@@ -6,8 +6,7 @@ import tensorflow as tf
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
..utils
import
logger
from
..utils
import
logger
from
..tfutils.tower
import
TowerContext
,
TowerFuncWrapper
from
..tfutils.tower
import
TowerContext
from
..input_source
import
PlaceholderInput
from
.training
import
GraphBuilder
from
.training
import
GraphBuilder
__all__
=
[
'SimplePredictBuilder'
]
__all__
=
[
'SimplePredictBuilder'
]
...
@@ -58,72 +57,3 @@ class SimplePredictBuilder(GraphBuilder):
...
@@ -58,72 +57,3 @@ class SimplePredictBuilder(GraphBuilder):
inputs
=
input
.
get_input_tensors
()
inputs
=
input
.
get_input_tensors
()
assert
isinstance
(
inputs
,
(
list
,
tuple
)),
inputs
assert
isinstance
(
inputs
,
(
list
,
tuple
)),
inputs
return
tower_fn
(
*
inputs
)
return
tower_fn
(
*
inputs
)
class
PredictorFactory
(
object
):
""" Make predictors from :class:`ModelDesc`."""
def
__init__
(
self
,
model
,
vs_name
=
''
):
"""
Args:
model (ModelDesc):
vs_name (str):
"""
self
.
_model
=
model
self
.
_vs_name
=
vs_name
self
.
_names_built
=
{}
def
build
(
self
,
tower_name
,
device
,
input
=
None
):
"""
Args:
tower_name (str):
device(str):
input (InputSource): must be setup already. If None, will use InputDesc from the model.
"""
logger
.
info
(
"Building predictor tower '{}' on device {} ..."
.
format
(
tower_name
,
device
))
assert
tower_name
not
in
self
.
_names_built
,
\
"Prediction tower with name '{}' already exists!"
.
format
(
tower_name
)
with
tf
.
device
(
device
),
\
TowerContext
(
tower_name
,
is_training
=
False
):
inputs_desc
=
self
.
_model
.
get_inputs_desc
()
if
input
is
None
:
input
=
PlaceholderInput
()
input
.
setup
(
inputs_desc
)
inputs
=
input
.
get_input_tensors
()
assert
isinstance
(
inputs
,
(
list
,
tuple
)),
inputs
def
tower_func
(
*
inputs
):
self
.
_model
.
build_graph
(
inputs
)
tower_func
=
TowerFuncWrapper
(
tower_func
,
inputs_desc
)
tower_func
(
*
inputs
)
self
.
_names_built
[
tower_name
]
=
tower_func
.
towers
[
0
]
return
self
.
_names_built
[
tower_name
]
def
has_built
(
self
,
tower_name
):
return
tower_name
in
self
.
_names_built
def
get_predictor
(
self
,
input_names
,
output_names
,
tower
):
"""
Args:
tower (int): use device '/gpu:{tower}' or use -1 for '/cpu:0'.
Returns:
an online predictor (which has to be used under a default session)
"""
tower_name
=
'towerp{}'
.
format
(
tower
)
device
=
'/gpu:{}'
.
format
(
tower
)
if
tower
>=
0
else
'/cpu:0'
# use a previously-built tower
# TODO check conflict with inference runner??
if
tower_name
not
in
self
.
_names_built
:
with
tf
.
variable_scope
(
self
.
_vs_name
,
reuse
=
True
):
handle
=
self
.
build
(
tower_name
,
device
)
else
:
handle
=
self
.
_names_built
[
tower_name
]
in_tensors
=
handle
.
get_tensors
(
input_names
)
out_tensors
=
handle
.
get_tensors
(
output_names
)
from
..predict
import
OnlinePredictor
# noqa TODO
return
OnlinePredictor
(
in_tensors
,
out_tensors
)
tensorpack/predict/base.py
View file @
0fdef168
...
@@ -146,7 +146,7 @@ class OnlinePredictor(PredictorBase):
...
@@ -146,7 +146,7 @@ class OnlinePredictor(PredictorBase):
class
OfflinePredictor
(
OnlinePredictor
):
class
OfflinePredictor
(
OnlinePredictor
):
""" A predictor built from a given config.
""" A predictor built from a given config.
A sin
lg
e-tower model will be built without any prefix. """
A sin
gl
e-tower model will be built without any prefix. """
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
"""
"""
...
@@ -156,9 +156,9 @@ class OfflinePredictor(OnlinePredictor):
...
@@ -156,9 +156,9 @@ class OfflinePredictor(OnlinePredictor):
self
.
graph
=
config
.
_maybe_create_graph
()
self
.
graph
=
config
.
_maybe_create_graph
()
with
self
.
graph
.
as_default
():
with
self
.
graph
.
as_default
():
input
=
PlaceholderInput
()
input
=
PlaceholderInput
()
input
.
setup
(
config
.
model
.
get_inputs_desc
()
)
input
.
setup
(
config
.
inputs_desc
)
with
TowerContext
(
''
,
is_training
=
False
):
with
TowerContext
(
''
,
is_training
=
False
):
config
.
model
.
build_graph
(
input
.
get_input_tensors
())
config
.
tower_func
(
*
input
.
get_input_tensors
())
input_tensors
=
get_tensors_by_names
(
config
.
input_names
)
input_tensors
=
get_tensors_by_names
(
config
.
input_names
)
output_tensors
=
get_tensors_by_names
(
config
.
output_names
)
output_tensors
=
get_tensors_by_names
(
config
.
output_names
)
...
...
tensorpack/predict/config.py
View file @
0fdef168
...
@@ -7,14 +7,17 @@ import six
...
@@ -7,14 +7,17 @@ import six
from
..graph_builder
import
ModelDescBase
from
..graph_builder
import
ModelDescBase
from
..tfutils
import
get_default_sess_config
from
..tfutils
import
get_default_sess_config
from
..tfutils.tower
import
TowerFuncWrapper
from
..tfutils.sessinit
import
SessionInit
,
JustCurrentSession
from
..tfutils.sessinit
import
SessionInit
,
JustCurrentSession
from
..tfutils.sesscreate
import
NewSessionCreator
__all__
=
[
'PredictConfig'
]
__all__
=
[
'PredictConfig'
]
class
PredictConfig
(
object
):
class
PredictConfig
(
object
):
def
__init__
(
self
,
model
,
def
__init__
(
self
,
model
=
None
,
inputs_desc
=
None
,
tower_func
=
None
,
session_creator
=
None
,
session_creator
=
None
,
session_init
=
None
,
session_init
=
None
,
input_names
=
None
,
input_names
=
None
,
...
@@ -24,9 +27,12 @@ class PredictConfig(object):
...
@@ -24,9 +27,12 @@ class PredictConfig(object):
):
):
"""
"""
Args:
Args:
model (ModelDescBase): the model to use.
model (ModelDescBase): the model to be used to obtain inputs_desc and tower_func.
inputs_desc ([InputDesc]):
tower_func: a callable which takes input tensors
session_creator (tf.train.SessionCreator): how to create the
session_creator (tf.train.SessionCreator): how to create the
session. Defaults to :class:`
sesscreate.New
SessionCreator()`.
session. Defaults to :class:`
tf.train.Chief
SessionCreator()`.
session_init (SessionInit): how to initialize variables of the session.
session_init (SessionInit): how to initialize variables of the session.
Defaults to do nothing.
Defaults to do nothing.
input_names (list): a list of input tensor names. Defaults to all
input_names (list): a list of input tensor names. Defaults to all
...
@@ -36,11 +42,20 @@ class PredictConfig(object):
...
@@ -36,11 +42,20 @@ class PredictConfig(object):
return_input (bool): same as in :attr:`PredictorBase.return_input`.
return_input (bool): same as in :attr:`PredictorBase.return_input`.
create_graph (bool): create a new graph, or use the default graph
create_graph (bool): create a new graph, or use the default graph
when then predictor is first initialized.
when then predictor is first initialized.
You need to set either `model`, or `inputs_desc` plus `tower_func`.
"""
"""
def
assert_type
(
v
,
tp
):
def
assert_type
(
v
,
tp
):
assert
isinstance
(
v
,
tp
),
v
.
__class__
assert
isinstance
(
v
,
tp
),
v
.
__class__
self
.
model
=
model
if
model
is
not
None
:
assert_type
(
self
.
model
,
ModelDescBase
)
assert_type
(
model
,
ModelDescBase
)
assert
inputs_desc
is
None
and
tower_func
is
None
self
.
inputs_desc
=
model
.
get_inputs_desc
()
self
.
tower_func
=
TowerFuncWrapper
(
model
.
build_graph
,
self
.
inputs_desc
)
else
:
assert
inputs_desc
is
not
None
and
tower_func
is
not
None
self
.
inputs_desc
=
inputs_desc
self
.
tower_func
=
TowerFuncWrapper
(
tower_func
,
inputs_desc
)
if
session_init
is
None
:
if
session_init
is
None
:
session_init
=
JustCurrentSession
()
session_init
=
JustCurrentSession
()
...
@@ -48,7 +63,7 @@ class PredictConfig(object):
...
@@ -48,7 +63,7 @@ class PredictConfig(object):
assert_type
(
self
.
session_init
,
SessionInit
)
assert_type
(
self
.
session_init
,
SessionInit
)
if
session_creator
is
None
:
if
session_creator
is
None
:
self
.
session_creator
=
New
SessionCreator
(
config
=
get_default_sess_config
())
self
.
session_creator
=
tf
.
train
.
Chief
SessionCreator
(
config
=
get_default_sess_config
())
else
:
else
:
self
.
session_creator
=
session_creator
self
.
session_creator
=
session_creator
...
...
tensorpack/predict/multigpu.py
View file @
0fdef168
...
@@ -5,7 +5,7 @@
...
@@ -5,7 +5,7 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
from
..utils
import
logger
from
..utils
import
logger
from
..graph_builder.predictor_factory
import
PredictorFactory
from
..graph_builder.predictor_factory
import
SimplePredictBuilder
from
..input_source
import
PlaceholderInput
from
..input_source
import
PlaceholderInput
from
.base
import
OnlinePredictor
from
.base
import
OnlinePredictor
...
@@ -28,13 +28,17 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
...
@@ -28,13 +28,17 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
self
.
return_input
=
config
.
return_input
self
.
return_input
=
config
.
return_input
with
self
.
graph
.
as_default
():
with
self
.
graph
.
as_default
():
handles
=
[]
handles
=
[]
factory
=
PredictorFactory
(
config
.
model
,
towers
)
input
=
PlaceholderInput
()
input
.
setup
(
config
.
inputs_desc
)
for
idx
,
t
in
enumerate
(
towers
):
for
idx
,
t
in
enumerate
(
towers
):
tower_name
=
'tower'
+
str
(
t
)
tower_name
=
'tower'
+
str
(
t
)
device
=
'/gpu:'
+
str
(
t
)
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
idx
>
0
):
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
idx
>
0
):
handles
.
append
(
factory
.
build
(
tower_name
,
device
))
builder
=
SimplePredictBuilder
(
ns_name
=
tower_name
,
device
=
t
)
builder
.
build
(
input
,
config
.
tower_func
)
handles
.
append
(
config
.
tower_func
.
towers
[
-
1
])
self
.
sess
=
config
.
session_creator
.
create_session
()
self
.
sess
=
config
.
session_creator
.
create_session
()
config
.
session_init
.
init
(
self
.
sess
)
config
.
session_init
.
init
(
self
.
sess
)
...
@@ -87,15 +91,15 @@ class DataParallelOfflinePredictor(OnlinePredictor):
...
@@ -87,15 +91,15 @@ class DataParallelOfflinePredictor(OnlinePredictor):
input_tensors
=
[]
input_tensors
=
[]
output_tensors
=
[]
output_tensors
=
[]
factory
=
PredictorFactory
(
config
.
model
,
towers
)
for
idx
,
t
in
enumerate
(
towers
):
for
idx
,
t
in
enumerate
(
towers
):
tower_name
=
'tower'
+
str
(
t
)
tower_name
=
'tower'
+
str
(
t
)
device
=
'/gpu:'
+
str
(
t
)
input
=
PlaceholderInput
(
tower_name
+
'/'
)
input
=
PlaceholderInput
(
tower_name
+
'/'
)
input
.
setup
(
config
.
model
.
get_inputs_desc
()
)
input
.
setup
(
config
.
inputs_desc
)
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
idx
>
0
):
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
idx
>
0
):
h
=
factory
.
build
(
tower_name
,
device
,
)
builder
=
SimplePredictBuilder
(
ns_name
=
tower_name
,
device
=
t
)
builder
.
build
(
input
,
config
.
tower_func
)
h
=
config
.
tower_func
.
towers
[
-
1
]
input_tensors
.
extend
(
h
.
get_tensors
(
config
.
input_names
))
input_tensors
.
extend
(
h
.
get_tensors
(
config
.
input_names
))
output_tensors
.
extend
(
h
.
get_tensors
(
config
.
output_names
))
output_tensors
.
extend
(
h
.
get_tensors
(
config
.
output_names
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment