Shashank Suhas / seminar-breakout / Commits / 34a5a809

Commit 34a5a809 authored Nov 29, 2016 by Yuxin Wu

dataparallelmultitower

parent 93beba57

Showing 3 changed files with 40 additions and 5 deletions (+40, -5)
tensorpack/models/model_desc.py     +10  -3
tensorpack/predict/base.py          +29  -2
tensorpack/predict/concurrency.py   +1   -0
tensorpack/models/model_desc.py

@@ -32,12 +32,19 @@ class ModelDesc(object):
             return self.reuse_input_vars()
         except KeyError:
             pass
+        ret = self.get_placeholders()
+        for v in ret:
+            tf.add_to_collection(INPUT_VARS_KEY, v)
+        return ret
+
+    def get_placeholders(self, prefix=''):
+        """ build placeholders with optional prefix, for each InputVar"""
         input_vars = self._get_input_vars()
         ret = []
         for v in input_vars:
             ret.append(tf.placeholder(
                 v.type, shape=v.shape,
-                name=v.name))
-        for v in ret:
-            tf.add_to_collection(INPUT_VARS_KEY, v)
+                name=prefix + v.name))
         return ret

     def reuse_input_vars(self):
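The new get_placeholders(prefix=...) hook is what makes the data-parallel predictor below possible: each prediction tower gets its own uniquely named copy of the model's input placeholders. A minimal sketch of the idea, assuming a toy ModelDesc subclass and that the InputVar spec class lives alongside ModelDesc at this revision (the model, shapes and names here are illustrative, not part of the commit):

    import tensorflow as tf
    from tensorpack.models.model_desc import ModelDesc, InputVar  # import path assumed

    class ToyModel(ModelDesc):
        def _get_input_vars(self):
            # a single float input; shape and name are made up for the example
            return [InputVar(tf.float32, (None, 28, 28), 'input')]

        def _build_graph(self, input_vars):
            pass  # network definition omitted; signature per this revision (assumed)

    model = ToyModel()
    plain = model.get_placeholders()                     # placeholder named 'input'
    tower0 = model.get_placeholders(prefix='towerp0-')   # placeholder named 'towerp0-input'

DataParallelOfflinePredictor, added in tensorpack/predict/base.py below, relies on exactly this naming scheme: it builds one 'towerp{k}-' set of placeholders per tower and later looks the tensors up again by name.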
tensorpack/predict/base.py

@@ -12,7 +12,8 @@ from ..tfutils import get_tensors_by_names, TowerContext

 __all__ = ['OnlinePredictor', 'OfflinePredictor',
            'AsyncPredictorBase',
-           'MultiTowerOfflinePredictor', 'build_multi_tower_prediction_graph']
+           'MultiTowerOfflinePredictor', 'build_multi_tower_prediction_graph',
+           'DataParallelOfflinePredictor']

 class PredictorBase(object):

@@ -128,7 +129,7 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
         for k in towers:
             output_vars = get_tensors_by_names(
-                ['{}{}/'.format(self.PREFIX, k) + n \
+                ['towerp{}/'.format(k) + n \
                  for n in config.output_names])
             self.predictors.append(OnlinePredictor(
                 self.sess, input_vars, output_vars, config.return_input))

@@ -139,3 +140,29 @@ class MultiTowerOfflinePredictor(OnlinePredictor):

     def get_predictors(self, n):
         return [self.predictors[k % len(self.predictors)] for k in range(n)]
+
+
+class DataParallelOfflinePredictor(OnlinePredictor):
+    def __init__(self, config, towers):
+        self.graph = tf.Graph()
+        with self.graph.as_default():
+            sess = tf.Session(config=config.session_config)
+            input_var_names = []
+            for k in towers:
+                input_vars = config.model.get_placeholders(prefix='towerp{}-'.format(k))
+                logger.info("Building graph for predictor tower {}...".format(k))
+                with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
+                        TowerContext('towerp{}'.format(k)):
+                    config.model.build_graph(input_vars)
+                    tf.get_variable_scope().reuse_variables()
+                input_var_names.extend([k.name for k in input_vars])
+            input_vars = get_tensors_by_names(input_var_names)
+            config.session_init.init(sess)
+            output_vars = []
+            for k in towers:
+                output_vars.extend(get_tensors_by_names(
+                    ['towerp{}/'.format(k) + n \
+                     for n in config.output_names]))
+
+            super(DataParallelOfflinePredictor, self).__init__(
+                sess, input_vars, output_vars, config.return_input)
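The new DataParallelOfflinePredictor builds one prediction tower per requested device inside a single graph and session, takes every tower's 'towerp{k}-'-prefixed placeholders as its inputs and every tower's named outputs as its outputs, so a single call can push a separate batch through each tower. A rough usage sketch, assuming a PredictConfig set up the same way as for the existing OfflinePredictor/MultiTowerOfflinePredictor (the checkpoint path, output name, keyword names and the exact calling convention of PredictorBase.__call__ are assumptions, not from this commit):

    import numpy as np
    from tensorpack.predict import PredictConfig, DataParallelOfflinePredictor  # exports assumed
    from tensorpack.tfutils.sessinit import SaverRestore

    config = PredictConfig(                      # kwarg names per this revision (assumed)
        model=ToyModel(),                        # the toy ModelDesc sketched above
        session_init=SaverRestore('train_log/model.ckpt'),  # hypothetical checkpoint
        output_names=['output'],                 # must name a tensor built by the model
        return_input=False)

    # two towers on GPU 0 and GPU 1; inputs become [towerp0-input, towerp1-input]
    pred = DataParallelOfflinePredictor(config, towers=[0, 1])

    batch0 = np.zeros((8, 28, 28), dtype='float32')
    batch1 = np.ones((8, 28, 28), dtype='float32')
    # one feed per tower in a single call (exact __call__ convention depends on
    # PredictorBase in this version)
    outputs = pred([batch0, batch1])

Unlike MultiTowerOfflinePredictor, which keeps one OnlinePredictor per tower and hands whole calls out round-robin via get_predictors, this predictor feeds all towers in the same session run, which is the data-parallel behaviour the commit message refers to.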
tensorpack/predict/concurrency.py

@@ -46,6 +46,7 @@ class MultiProcessPredictWorker(multiprocessing.Process):
             from tensorpack.models._common import disable_layer_logging
             disable_layer_logging()
         self.predictor = OfflinePredictor(self.config)
+        import sys
         if self.idx == 0:
             with self.predictor.graph.as_default():
                 describe_model()