seminar-breakout / Commits / fa025551

Commit fa025551, authored Feb 22, 2017 by Yuxin Wu
parent dbfa9982

refactor predictors.

Showing 7 changed files with 123 additions and 74 deletions (+123 -74)
tensorpack/callbacks/inference_runner.py   +2  -1
tensorpack/predict/base.py                +53 -19
tensorpack/predict/config.py              +14  -3
tensorpack/predict/multigpu.py            +36 -14
tensorpack/tfutils/tower.py                +2  -1
tensorpack/train/input_data.py             +1  -0
tensorpack/train/predict.py               +15 -36
tensorpack/callbacks/inference_runner.py  (+2 -1)

@@ -183,6 +183,7 @@ class FeedfreeInferenceRunner(Triggerable):
     def _setup_graph(self):
         self._find_input_tensors()  # tensors
+        # TODO reuse predictor code
         # overwrite the FeedfreeInferenceRunner name scope
         with tf.variable_scope(tf.get_variable_scope(), reuse=True), \
                 tf.name_scope(None), \
@@ -190,7 +191,7 @@ class FeedfreeInferenceRunner(Triggerable):
             def fn(_):
                 self.trainer.model.build_graph(self._input_tensors)
             build_prediction_graph(fn, [0], prefix=self._prefix)
-        self._tower_prefix = TowerContext.get_predict_tower_name(self._prefix, 0)
+        self._tower_prefix = TowerContext.get_predict_tower_name(0, self._prefix)
         self._find_output_tensors()

tensorpack/predict/base.py  (+53 -19)

@@ -8,11 +8,15 @@ import tensorflow as tf
 import six

 from ..utils import logger
+from ..utils.argtools import memoized
+from ..utils.naming import SUMMARY_BACKUP_KEYS
 from ..tfutils import get_tensors_by_names, TowerContext
+from ..tfutils.collection import freeze_collection

 __all__ = ['PredictorBase', 'AsyncPredictorBase',
            'OnlinePredictor', 'OfflinePredictor',
            'get_predict_func',
+           'PredictorTowerBuilder',
            'build_prediction_graph',
            ]
@@ -119,14 +123,15 @@ class OnlinePredictor(PredictorBase):
 class OfflinePredictor(OnlinePredictor):
-    """ A predictor built from a given config, in a new graph. """
+    """ A predictor built from a given config.
+        A sinlge-tower model will be built without any prefix. """

     def __init__(self, config):
         """
         Args:
             config (PredictConfig): the config to use.
         """
-        self.graph = tf.Graph()
+        self.graph = config._maybe_create_graph()
         with self.graph.as_default():
             input_placehdrs = config.model.get_reused_placehdrs()
             with TowerContext('', False):
@@ -148,23 +153,52 @@ def get_predict_func(config):
     return OfflinePredictor(config)


-def build_prediction_graph(build_tower_fn, towers=[0], prefix=''):
-    """
-    Build graph on each tower.
-    Args:
-        build_tower_fn: a function that will be called inside each tower,
-            taking tower id as the argument.
-        towers: a list of relative GPU id.
-        prefix: an extra prefix in tower name. The final tower prefix will be
-            determined by :meth:`TowerContext.get_predict_tower_name`.
-    """
-    for idx, k in enumerate(towers):
-        logger.info(
-            "Building prediction graph for towerid={} with prefix='{}' ...".format(k, prefix))
-        towername = TowerContext.get_predict_tower_name(prefix, k)
-        with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
-                TowerContext(towername, is_training=False), \
-                tf.variable_scope(tf.get_variable_scope(),
-                                  reuse=True if idx > 0 else None):
-            build_tower_fn(k)
+class PredictorTowerBuilder(object):
+    """
+    A builder which caches the predictor tower it has built.
+    """
+    def __init__(self, build_tower_fn, prefix=''):
+        """
+        Args:
+            build_tower_fn: a function that will be called inside each tower, taking tower id as the argument.
+            prefix: an extra prefix in tower name. The final tower prefix will be
+                determined by :meth:`TowerContext.get_predict_tower_name`.
+        """
+        self._fn = build_tower_fn
+        self._prefix = prefix
+
+    @memoized
+    def build(self, tower):
+        """
+        Args:
+            tower (int): the tower will be built on device '/gpu:{tower}', or
+                '/cpu:0' if tower is -1.
+        """
+        towername = TowerContext.get_predict_tower_name(tower, self._prefix)
+        if self._prefix:
+            msg = "Building predictor graph {} on gpu={} with prefix='{}' ...".format(
+                towername, tower, self._prefix)
+        else:
+            msg = "Building predictor graph {} on gpu={} ...".format(towername, tower)
+        logger.info(msg)
+        # No matter where this get called, clear any existing name scope.
+        with tf.name_scope(None), \
+                freeze_collection(SUMMARY_BACKUP_KEYS), \
+                tf.device('/gpu:{}'.format(tower) if tower >= 0 else '/cpu:0'), \
+                TowerContext(towername, is_training=False):
+            self._fn(tower)
+
+
+def build_prediction_graph(build_tower_fn, towers=[0], prefix=''):
+    """
+    Execute `build_tower_fn` on each tower.
+    Just a wrapper on :class:`PredictorTowerBuilder` to run on several towers
+    together.
+    """
+    builder = PredictorTowerBuilder(build_tower_fn, prefix)
+    for idx, t in enumerate(towers):
+        # The first variable scope may or may not reuse (depending on the existing
+        # context), but the rest have to reuse.
+        with tf.variable_scope(tf.get_variable_scope(),
+                               reuse=True if idx > 0 else None):
+            builder.build(t)

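For orientation, here is a standalone sketch (plain Python 3, not tensorpack code) of the caching behaviour that `PredictorTowerBuilder.build` gets from `@memoized`: real work happens only the first time it is called for a given tower id, which is what lets later callers merely "ensure the tower exists" without rebuilding. `functools.lru_cache` stands in for tensorpack's `memoized`.

import functools

class TowerBuilderSketch(object):
    def __init__(self, build_fn, prefix=''):
        self._fn = build_fn
        self._prefix = prefix

    @functools.lru_cache(maxsize=None)   # stand-in for tensorpack's @memoized
    def build(self, tower):
        # real work happens only on the first call per (self, tower)
        print("building predictor tower {} (prefix={!r})".format(tower, self._prefix))
        self._fn(tower)

b = TowerBuilderSketch(lambda t: None)
b.build(0)   # builds
b.build(0)   # cached, nothing printed
b.build(1)   # a different tower id builds again
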
tensorpack/predict/config.py  (+14 -3)

@@ -2,6 +2,7 @@
 # File: config.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

+import tensorflow as tf
 import six

 from ..models import ModelDesc
@@ -17,10 +18,12 @@ class PredictConfig(object):
     def __init__(self, model,
                  session_creator=None,
                  session_init=None,
-                 session_config=None,
                  input_names=None,
                  output_names=None,
-                 return_input=False):
+                 return_input=False,
+                 create_graph=True,
+                 session_config=None,  # deprecated
+                 ):
         """
         Args:
             model (ModelDesc): the model to use.
@@ -32,7 +35,9 @@ class PredictConfig(object):
                 inputs of the model.
             output_names (list): a list of names of the output tensors to predict, the
                 tensors can be any computable tensor in the graph.
-            return_input: same as in :attr:`PredictorBase.return_input`.
+            return_input (bool): same as in :attr:`PredictorBase.return_input`.
+            create_graph (bool): create a new graph, or use the default graph
+                when then predictor is first initialized.
         """
         def assert_type(v, tp):
             assert isinstance(v, tp), v.__class__
@@ -68,3 +73,9 @@ class PredictConfig(object):
         assert len(self.output_names), self.output_names

         self.return_input = bool(return_input)
+        self.create_graph = bool(create_graph)
+
+    def _maybe_create_graph(self):
+        if self.create_graph:
+            return tf.Graph()
+        return tf.get_default_graph()

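To make the effect of the new `create_graph` flag concrete, a minimal sketch assuming a TF 1.x graph-mode environment (it mirrors the `_maybe_create_graph` logic above; the function below is an illustration, not tensorpack's API):

import tensorflow as tf   # TF 1.x graph-mode API assumed

def maybe_create_graph(create_graph):
    # mirrors PredictConfig._maybe_create_graph() from this commit
    if create_graph:
        return tf.Graph()              # the predictor gets its own private graph
    return tf.get_default_graph()      # the predictor is built into the caller's graph

g = tf.Graph()
with g.as_default():
    assert maybe_create_graph(False) is g                          # reuse the current graph
    assert maybe_create_graph(True) is not tf.get_default_graph()  # otherwise, a fresh graph
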
tensorpack/predict/multigpu.py  (+36 -14)

@@ -3,9 +3,8 @@
 # File: multigpu.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

-import tensorflow as tf
-from ..utils import logger
-from ..tfutils import get_tensors_by_names, TowerContext
+from ..utils import logger
+from ..tfutils import get_tensors_by_names, TowerContext, get_op_tensor_name
 from .base import OnlinePredictor, build_prediction_graph

 __all__ = ['MultiTowerOfflinePredictor',
@@ -21,10 +20,12 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
             config (PredictConfig): the config to use.
             towers: a list of relative GPU id.
         """
-        self.graph = tf.Graph()
+        assert len(towers) > 0
+        self.graph = config._maybe_create_graph()
         self.predictors = []
         with self.graph.as_default():
+            # TODO backup summary keys?
+            placeholder_names = set([k.name for k in config.model.get_inputs_desc()])

             def fn(_):
                 config.model.build_graph(config.model.get_reused_placehdrs())
             build_prediction_graph(fn, towers)
@@ -32,25 +33,46 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
             self.sess = config.session_creator.create_session()
             config.session_init.init(self.sess)
-            input_tensors = get_tensors_by_names(config.input_names)
+            get_tensor_fn = MultiTowerOfflinePredictor.get_tensors_maybe_in_tower
             for k in towers:
-                output_tensors = get_tensors_by_names(
-                    [TowerContext.get_predict_towre_name('', k) + '/' + n
-                     for n in config.output_names])
+                input_tensors = get_tensor_fn(placeholder_names, config.input_names, k)
+                output_tensors = get_tensor_fn(placeholder_names, config.output_names, k)
                 self.predictors.append(OnlinePredictor(
                     input_tensors, output_tensors, config.return_input, self.sess))

+    @staticmethod
+    def get_tensors_maybe_in_tower(placeholder_names, names, k):
+        def maybe_inside_tower(name):
+            name = get_op_tensor_name(name)[0]
+            if name in placeholder_names:
+                return name
+            else:
+                # if the name is not a placeholder, use it's name in each tower
+                return TowerContext.get_predict_tower_name(k) + '/' + name
+        names = map(maybe_inside_tower, names)
+        tensors = get_tensors_by_names(names)
+        return tensors
+
     def _do_call(self, dp):
         # use the first tower for compatible PredictorBase interface
         return self.predictors[0]._do_call(dp)

-    def get_predictors(self, n):
+    def get_predictor(self, n):
         """
         Returns:
-            PredictorBase: the nth predictor on the nth GPU.
+            PredictorBase: the nth predictor on the nth tower.
         """
-        return [self.predictors[k % len(self.predictors)] for k in range(n)]
+        l = len(self.predictors)
+        if n >= l:
+            logger.warn("n > #towers, will assign predictor to GPU by round-robin")
+        return [self.predictors[k % l] for k in range(n)]
+
+    def get_predictors(self):
+        """
+        Returns:
+            list[PredictorBase]: a list of predictor
+        """
+        return self.predictors


 class DataParallelOfflinePredictor(OnlinePredictor):
@@ -66,7 +88,7 @@ class DataParallelOfflinePredictor(OnlinePredictor):
             config (PredictConfig): the config to use.
             towers: a list of relative GPU id.
         """
-        self.graph = tf.Graph()
+        self.graph = config._maybe_create_graph()
         with self.graph.as_default():
             input_names = []
             output_tensors = []

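The name mapping in `get_tensors_maybe_in_tower` can be illustrated without TensorFlow: placeholders are shared by all towers and keep their plain names, while every other tensor is looked up under the tower's name scope. A standalone sketch (the `'towerp0'` tower name is only an illustrative stand-in for whatever `TowerContext.get_predict_tower_name` actually returns):

def tensors_maybe_in_tower_sketch(placeholder_names, names, tower_name):
    def map_name(name):
        op_name = name.split(':')[0]        # crude stand-in for get_op_tensor_name(name)[0]
        if op_name in placeholder_names:
            return op_name                  # placeholders are shared, no tower prefix
        return tower_name + '/' + op_name   # everything else lives inside the tower
    return [map_name(n) for n in names]

print(tensors_maybe_in_tower_sketch({'input'}, ['input:0', 'prob:0'], 'towerp0'))
# -> ['input', 'towerp0/prob']
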
tensorpack/tfutils/tower.py  (+2 -1)

@@ -73,7 +73,7 @@ class TowerContext(object):
         return graph.get_tensor_by_name(newname)

     @staticmethod
-    def get_predict_tower_name(prefix, towerid=0):
+    def get_predict_tower_name(towerid=0, prefix=''):
         """
         Args:
             prefix(str): an alphanumeric prefix.
@@ -91,6 +91,7 @@ class TowerContext(object):
         assert _CurrentTowerContext is None, \
             "Nesting TowerContext!"
         _CurrentTowerContext = self
+        # TODO enter name_scope(None) first
         if len(self._name):
             self._scope = tf.name_scope(self._name)
             return self._scope.__enter__()

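The visible change here is only the argument order: call sites updated in this commit (inference_runner.py, base.py, multigpu.py) now pass the tower id first and the optional prefix second, so the common unprefixed case is just `get_predict_tower_name(k)`. A hypothetical sketch of that calling convention; the body and the `'towerp'`/`'/'` name format below are made up for illustration, only the signature order matches the commit:

def get_predict_tower_name_sketch(towerid=0, prefix=''):
    # illustrative format only; the real name is tensorpack-internal
    name = 'towerp' + str(towerid)
    return prefix + '/' + name if prefix else name

print(get_predict_tower_name_sketch())                      # -> towerp0
print(get_predict_tower_name_sketch(1))                     # -> towerp1
print(get_predict_tower_name_sketch(0, 'InferenceTower'))   # -> InferenceTower/towerp0
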
tensorpack/train/input_data.py  (+1 -0)

@@ -124,6 +124,7 @@ class QueueInput(FeedfreeInput):
     def size(self):
         return self.ds.size()

+    # TODO XXX use input data mapping. not all placeholders are needed
     def _setup(self, trainer):
         self.input_placehdrs = trainer.model.get_reused_placehdrs()
         assert len(self.input_placehdrs) > 0, \

tensorpack/train/predict.py  (+15 -36)

@@ -4,11 +4,8 @@
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

 import tensorflow as tf
-from ..utils import SUMMARY_BACKUP_KEYS, PREDICT_TOWER
-from ..tfutils.collection import freeze_collection
-from ..utils.argtools import memoized
-from ..tfutils import get_tensors_by_names, get_op_tensor_name
-from ..predict import OnlinePredictor, build_prediction_graph
+from ..predict import (OnlinePredictor,
+                       PredictorTowerBuilder, MultiTowerOfflinePredictor)

 __all__ = ['PredictorFactory']
@@ -23,45 +20,27 @@ class PredictorFactory(object):
         """
         self.model = trainer.model
         self.towers = trainer.config.predict_tower
+
+        def fn(_):
+            self.model.build_graph(self.model.get_reused_placehdrs())
+        self._tower_builder = PredictorTowerBuilder(fn)
         assert isinstance(self.towers, list)

     # TODO sess option
     def get_predictor(self, input_names, output_names, tower):
         """
         Args:
-            tower (int): need the kth tower (not the gpu id)
+            tower (int): need the kth tower (not the gpu id, but the id in TrainConfig.predict_tower)
         Returns:
             an online predictor (which has to be used under a default session)
         """
-        self._build_predict_tower()
-        tower = self.towers[tower]
+        tower = self.towers[tower]   # TODO is it good?
+        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+            # just ensure the tower exists. won't rebuild
+            self._tower_builder.build(tower)

         placeholder_names = set([k.name for k in self.model.get_inputs_desc()])
-
-        def get_name_in_tower(name):
-            return PREDICT_TOWER + str(tower) + '/' + name
-
-        def maybe_inside_tower(name):
-            name = get_op_tensor_name(name)[0]
-            if name in placeholder_names:
-                return name
-            else:
-                return get_name_in_tower(name)
-
-        input_names = map(maybe_inside_tower, input_names)
-        raw_input_tensors = get_tensors_by_names(input_names)
-
-        output_names = map(get_name_in_tower, output_names)
-        output_tensors = get_tensors_by_names(output_names)
-        return OnlinePredictor(raw_input_tensors, output_tensors)
-
-    @memoized
-    def _build_predict_tower(self):
-        # build_predict_tower might get called anywhere, but 'PREDICT_TOWER'
-        # should always be the outermost name scope
-        with tf.name_scope(None), \
-                freeze_collection(SUMMARY_BACKUP_KEYS), \
-                tf.variable_scope(tf.get_variable_scope(), reuse=True):
-            def fn(_):
-                self.model.build_graph(self.model.get_reused_placehdrs())
-            build_prediction_graph(fn, self.towers)
+        get_tensor_fn = MultiTowerOfflinePredictor.get_tensors_maybe_in_tower
+        in_tensors = get_tensor_fn(placeholder_names, input_names, tower)
+        out_tensors = get_tensor_fn(placeholder_names, output_names, tower)
+        return OnlinePredictor(in_tensors, out_tensors)

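One behavioural detail worth spelling out from the updated docstring: the `tower` argument of `get_predictor` indexes into `TrainConfig.predict_tower`, so it selects the k-th predict tower rather than naming a GPU directly. A self-contained illustration with a made-up `predict_tower` list:

predict_tower = [2, 3]   # hypothetical TrainConfig.predict_tower: run prediction on GPUs 2 and 3

def resolve_gpu(k):
    # mirrors `tower = self.towers[tower]` in PredictorFactory.get_predictor
    return predict_tower[k]

print(resolve_gpu(0))   # -> 2: the first predict tower runs on GPU 2
print(resolve_gpu(1))   # -> 3: the second predict tower runs on GPU 3
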