Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
4ee1e735
Commit
4ee1e735
authored
Jul 13, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
PredictorFactory build tower by itself.
parent
e839c50d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
73 additions
and
38 deletions
+73
-38
tensorpack/graph_builder/predictor_factory.py
tensorpack/graph_builder/predictor_factory.py
+65
-34
tensorpack/predict/base.py
tensorpack/predict/base.py
+2
-1
tensorpack/predict/multigpu.py
tensorpack/predict/multigpu.py
+2
-2
tensorpack/tfutils/summary.py
tensorpack/tfutils/summary.py
+4
-1
No files found.
tensorpack/graph_builder/predictor_factory.py
View file @
4ee1e735
...
...
@@ -3,58 +3,89 @@
# File: predictor_factory.py
import
tensorflow
as
tf
# from ..tfutils.tower import TowerContext
from
..predict
import
(
OnlinePredictor
,
PredictorTowerBuilder
)
from
..utils
import
logger
from
..tfutils.common
import
get_op_tensor_name
,
get_tensors_by_names
from
..tfutils.tower
import
TowerContext
from
..tfutils.collection
import
freeze_collection
from
..predict
import
OnlinePredictor
from
..utils.naming
import
TOWER_FREEZE_KEYS
__all__
=
[
'PredictorFactory'
]
# class PredictorTowerBuilder(object):
# def __init__(self, model):
# self._model = model
# self._towers = []
#
# def build(self, tower_name, device, input=None):
# with tf.device(device), TowerContext(tower_name, is_training=False):
# if input is None:
# input = self._model.get_reused_placehdrs()
# self._model.build_graph(input)
#
#
# SMART
class
PredictorTowerHandle
(
object
):
def
__init__
(
self
,
tower_name
,
input_tensors
):
self
.
_tower_name
=
tower_name
self
.
_input_tensors
=
input_tensors
self
.
_input_names
=
[
get_op_tensor_name
(
k
.
name
)[
1
]
for
k
in
input_tensors
]
def
get_tensors
(
self
,
names
):
def
maybe_inside_tower
(
name
):
name
=
get_op_tensor_name
(
name
)[
1
]
if
name
in
self
.
_input_names
:
return
name
else
:
# if the name is not a placeholder, use it's name in each tower
return
self
.
_tower_name
+
'/'
+
name
names
=
list
(
map
(
maybe_inside_tower
,
names
))
tensors
=
get_tensors_by_names
(
names
)
return
tensors
class
PredictorFactory
(
object
):
""" Make predictors from :class:`ModelDesc` and cache them."""
def
__init__
(
self
,
model
,
towers
,
vs_name
):
"""
Args:
model (ModelDesc):
towers (list[int]): list of available gpu id
vs_name (str):
"""
self
.
model
=
model
self
.
towers
=
towers
self
.
vs_name
=
vs_name
assert
isinstance
(
towers
,
list
),
towers
self
.
_model
=
model
self
.
_towers
=
towers
self
.
_vs_name
=
vs_name
self
.
_names_built
=
{}
def
fn
(
_
):
self
.
model
.
build_graph
(
self
.
model
.
get_reused_placehdrs
())
self
.
_tower_builder
=
PredictorTowerBuilder
(
fn
)
assert
isinstance
(
self
.
towers
,
list
),
self
.
towers
def
build
(
self
,
tower_name
,
device
,
input
=
None
):
logger
.
info
(
"Building predictor graph {} on device {} ..."
.
format
(
tower_name
,
device
))
assert
tower_name
not
in
self
.
_names_built
with
tf
.
device
(
device
),
\
TowerContext
(
tower_name
,
is_training
=
False
),
\
freeze_collection
(
TOWER_FREEZE_KEYS
):
if
input
is
None
:
input
=
self
.
_model
.
get_reused_placehdrs
()
else
:
input
=
input
.
get_input_tensors
()
assert
isinstance
(
input
,
(
list
,
tuple
)),
input
self
.
_model
.
build_graph
(
input
)
self
.
_names_built
[
tower_name
]
=
PredictorTowerHandle
(
tower_name
,
input
)
return
self
.
_names_built
[
tower_name
]
def
has_built
(
self
,
tower_name
):
return
tower_name
in
self
.
_names_built
def
get_predictor
(
self
,
input_names
,
output_names
,
tower
):
"""
Args:
tower (int): need the kth tower (not the gpu id, but the id in TrainConfig.predict_tower)
Returns:
an online predictor (which has to be used under
a
default session)
an online predictor (which has to be used under
the
default session)
"""
tower
=
self
.
towers
[
tower
]
# just ensure the tower exists. won't rebuild (memoized)
with
tf
.
variable_scope
(
self
.
vs_name
,
reuse
=
True
):
self
.
_tower_builder
.
build
(
tower
)
placeholder_names
=
set
([
k
.
name
for
k
in
self
.
model
.
get_inputs_desc
()])
get_tensor_fn
=
PredictorTowerBuilder
.
get_tensors_maybe_in_tower
in_tensors
=
get_tensor_fn
(
placeholder_names
,
input_names
,
tower
)
out_tensors
=
get_tensor_fn
(
placeholder_names
,
output_names
,
tower
)
tower
=
self
.
_towers
[
tower
]
device
=
'/gpu:{}'
.
format
(
tower
)
if
tower
>=
0
else
'/cpu:0'
tower_name
=
TowerContext
.
get_predict_tower_name
(
max
(
tower
,
0
))
# XXX
# use a previously-built tower
# TODO conflict with inference runner??
if
not
self
.
has_built
(
tower_name
):
with
tf
.
variable_scope
(
self
.
_vs_name
,
reuse
=
True
):
handle
=
self
.
build
(
tower_name
,
device
)
else
:
handle
=
self
.
_names_built
[
tower_name
]
in_tensors
=
handle
.
get_tensors
(
input_names
)
out_tensors
=
handle
.
get_tensors
(
output_names
)
return
OnlinePredictor
(
in_tensors
,
out_tensors
)
tensorpack/predict/base.py
View file @
4ee1e735
...
...
@@ -10,7 +10,8 @@ import six
from
..utils
import
logger
from
..utils.argtools
import
memoized
from
..utils.naming
import
TOWER_FREEZE_KEYS
from
..tfutils
import
get_tensors_by_names
,
TowerContext
,
get_op_tensor_name
from
..tfutils.common
import
get_tensors_by_names
,
get_op_tensor_name
from
..tfutils.tower
import
TowerContext
from
..tfutils.collection
import
freeze_collection
__all__
=
[
'PredictorBase'
,
'AsyncPredictorBase'
,
...
...
tensorpack/predict/multigpu.py
View file @
4ee1e735
...
...
@@ -47,7 +47,7 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
def
get_predictor
(
self
,
n
):
"""
Returns:
PredictorBase
: the nth predictor on the nth tower.
OnlinePredictor
: the nth predictor on the nth tower.
"""
l
=
len
(
self
.
predictors
)
if
n
>=
l
:
...
...
@@ -57,7 +57,7 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
def
get_predictors
(
self
):
"""
Returns:
list[
PredictorBase
]: a list of predictor
list[
OnlinePredictor
]: a list of predictor
"""
return
self
.
predictors
...
...
tensorpack/tfutils/summary.py
View file @
4ee1e735
...
...
@@ -49,11 +49,14 @@ def create_image_summary(name, val):
val
=
val
.
astype
(
'uint8'
)
s
=
tf
.
Summary
()
for
k
in
range
(
n
):
arr
=
val
[
k
]
if
arr
.
shape
[
2
]
==
1
:
# scipy doesn't accept (h,w,1)
arr
=
arr
[:,
:,
0
]
tag
=
name
if
n
==
1
else
'{}/{}'
.
format
(
name
,
k
)
buf
=
io
.
BytesIO
()
# scipy assumes RGB
scipy
.
misc
.
toimage
(
val
[
k
]
)
.
save
(
buf
,
format
=
'png'
)
scipy
.
misc
.
toimage
(
arr
)
.
save
(
buf
,
format
=
'png'
)
img
=
tf
.
Summary
.
Image
()
img
.
height
=
h
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment