Commit 8a80b036
authored Jul 14, 2017 by Yuxin Wu
remove PredictorTowerBuilder

parent 7699fd9b
Showing 1 changed file with 1 addition and 79 deletions.

tensorpack/predict/base.py (+1, -79)
@@ -7,17 +7,11 @@ from abc import abstractmethod, ABCMeta
 import tensorflow as tf
 import six
 
-from ..utils import logger
-from ..utils.argtools import memoized
-from ..utils.naming import TOWER_FREEZE_KEYS
-from ..tfutils.common import get_tensors_by_names, get_op_tensor_name
+from ..tfutils.common import get_tensors_by_names
 from ..tfutils.tower import TowerContext
-from ..tfutils.collection import freeze_collection
 
 __all__ = ['PredictorBase', 'AsyncPredictorBase',
            'OnlinePredictor', 'OfflinePredictor',
-           'PredictorTowerBuilder', 'build_prediction_graph',
            ]
@@ -144,75 +138,3 @@ class OfflinePredictor(OnlinePredictor):
             config.session_init.init(sess)
             super(OfflinePredictor, self).__init__(
                 input_tensors, output_tensors, config.return_input, sess)
-
-
-class PredictorTowerBuilder(object):
-    """
-    A builder which caches the predictor tower it has built.
-    """
-    def __init__(self, build_tower_fn, prefix=''):
-        """
-        Args:
-            build_tower_fn: a function that will be called inside each tower, taking tower id as the argument.
-            prefix: an extra prefix in tower name. The final tower prefix will be
-                determined by :meth:`TowerContext.get_predict_tower_name`.
-        """
-        self._fn = build_tower_fn
-        self._prefix = prefix
-
-    @memoized
-    def build(self, tower):
-        """
-        Args:
-            tower (int): the tower will be built on device '/gpu:{tower}', or
-                '/cpu:0' if tower is -1.
-        """
-        towername = TowerContext.get_predict_tower_name(tower, self._prefix)
-        if self._prefix:
-            msg = "Building predictor graph {} on gpu={} with prefix='{}' ...".format(
-                towername, tower, self._prefix)
-        else:
-            msg = "Building predictor graph {} on gpu={} ...".format(towername, tower)
-        logger.info(msg)
-        # No matter where this get called, clear any existing name scope.
-        device = '/gpu:{}'.format(tower) if tower >= 0 else '/cpu:0'
-        with tf.name_scope(None), \
-                freeze_collection(TOWER_FREEZE_KEYS), \
-                tf.device(device), \
-                TowerContext(towername, is_training=False):
-            self._fn(tower)
-
-    # useful only when the placeholders don't have tower prefix
-    # note that in DataParallel predictor, placeholders do have tower prefix
-    @staticmethod
-    def get_tensors_maybe_in_tower(placeholder_names, names, tower, prefix=''):
-        """
-        Args:
-            placeholders (list): A list of __op__ name.
-            tower (int): relative GPU id.
-        """
-        def maybe_inside_tower(name):
-            name = get_op_tensor_name(name)[0]
-            if name in placeholder_names:
-                return name
-            else:
-                # if the name is not a placeholder, use it's name in each tower
-                return TowerContext.get_predict_tower_name(tower, prefix) + '/' + name
-        names = list(map(maybe_inside_tower, names))
-        tensors = get_tensors_by_names(names)
-        return tensors
-
-
-def build_prediction_graph(build_tower_fn, towers=[0], prefix=''):
-    """
-    Execute `build_tower_fn` on each tower.
-    Just a wrapper on :class:`PredictorTowerBuilder` to run on several towers
-    together.
-    """
-    builder = PredictorTowerBuilder(build_tower_fn, prefix)
-    for idx, t in enumerate(towers):
-        # The first variable scope may or may not reuse (depending on the existing
-        # context), but the rest have to reuse.
-        with tf.variable_scope(tf.get_variable_scope(),
-                               reuse=True if idx > 0 else None):
-            builder.build(t)
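For context, a minimal sketch of how the removed helpers were typically called before this commit, based only on the docstrings in the deleted code above. `model_build_fn` is a hypothetical placeholder for a per-tower graph-building function; actual call sites in tensorpack may have looked different.

# Sketch only: usage of the API removed by this commit, against the
# pre-commit version of tensorpack/predict/base.py.
from tensorpack.predict.base import PredictorTowerBuilder, build_prediction_graph


def model_build_fn(tower):
    """Hypothetical build function, called once per tower.

    `tower` is the relative GPU id (-1 means the tower runs on /cpu:0).
    """
    pass  # build the model's inference graph here


# Build the prediction graph on GPU 0 and GPU 1. This wraps
# PredictorTowerBuilder and reuses variables for every tower after the first.
build_prediction_graph(model_build_fn, towers=[0, 1])

# Or drive the builder directly; `build` is memoized, so repeated calls with
# the same tower id construct that tower only once per builder instance.
builder = PredictorTowerBuilder(model_build_fn, prefix='')
builder.build(0)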