Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
80722088
Commit
80722088
authored
May 29, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
initial version of multithread predictor
parent
5ccaea83
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
79 additions
and
30 deletions
+79
-30
opt-requirements.txt
opt-requirements.txt
+1
-0
tensorpack/predict/concurrency.py
tensorpack/predict/concurrency.py
+52
-16
tensorpack/train/base.py
tensorpack/train/base.py
+9
-1
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+17
-13
No files found.
opt-requirements.txt
View file @
80722088
...
...
@@ -2,3 +2,4 @@ nltk
h5py
pyzmq
subprocess32
tornado
tensorpack/predict/concurrency.py
View file @
80722088
...
...
@@ -5,6 +5,9 @@
import
multiprocessing
,
threading
import
tensorflow
as
tf
from
six.moves
import
queue
,
range
from
..utils.concurrency
import
DIE
from
..tfutils.modelutils
import
describe_model
from
..utils
import
logger
...
...
@@ -12,10 +15,20 @@ from ..tfutils import *
from
.common
import
*
__all__
=
[
'MultiProcessPredictWorker'
,
'MultiProcessQueuePredictWorker'
]
try
:
if
six
.
PY2
:
from
tornado.concurrent
import
Future
else
:
from
concurrent.futures
import
Future
except
ImportError
:
logger
.
warn
(
"Cannot import Future in either tornado.concurrent or py3 standard lib. MultiThreadAsyncPredictor won't be available."
)
__all__
=
[
'MultiProcessPredictWorker'
,
'MultiProcessQueuePredictWorker'
]
else
:
__all__
=
[
'MultiProcessPredictWorker'
,
'MultiProcessQueuePredictWorker'
,
'MultiThreadAsyncPredictor'
]
class
MultiProcessPredictWorker
(
multiprocessing
.
Process
):
""" Base class for predict worker that runs in multiprocess"""
""" Base class for predict worker that runs
offline
in multiprocess"""
def
__init__
(
self
,
idx
,
gpuid
,
config
):
"""
:param idx: index of the worker. the 0th worker will print log.
...
...
@@ -44,7 +57,7 @@ class MultiProcessPredictWorker(multiprocessing.Process):
describe_model
()
class
MultiProcessQueuePredictWorker
(
MultiProcessPredictWorker
):
""" A
worker process to run predictor on one GPU
"""
""" A
predictor worker that takes input and produces output by queue
"""
def
__init__
(
self
,
idx
,
gpuid
,
inqueue
,
outqueue
,
config
):
"""
:param inqueue: input queue to get data point. elements are (task_id, dp)
...
...
@@ -64,17 +77,40 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
else
:
self
.
outqueue
.
put
((
tid
,
self
.
func
(
dp
)))
#class CurrentSessionPredictor():
#def __init__(self, idx, gpuid, config):
#"""
#:param idx: index of the worker. the 0th worker will print log.
#:param gpuid: absolute id of the GPU to be used. set to -1 to use CPU.
#:param config: a `PredictConfig`
#"""
#super(MultiProcessPredictWorker, self).__init__()
#self.idx = idx
#self.gpuid = gpuid
#self.config = config
class
PerdictorWorkerThread
(
threading
.
Thread
):
def
__init__
(
self
,
queue
,
pred_func
):
super
(
PredictorWorkerThread
,
self
)
.
__init__
()
self
.
queue
=
queue
self
.
func
=
pred_func
self
.
daemon
=
True
def
run
(
self
):
while
True
:
inputs
,
f
=
self
.
queue
.
get
()
outputs
=
self
.
func
(
inputs
)
f
.
set_result
(
outputs
)
class
MultiThreadAsyncPredictor
(
object
):
"""
An online predictor (use the current active session) that works with
QueueInputTrainer. Use async interface, support multi-thread and multi-GPU.
"""
def
__init__
(
self
,
trainer
,
input_names
,
output_names
,
nr_thread
):
"""
:param trainer: a `QueueInputTrainer` instance.
"""
self
.
input_queue
=
queue
.
Queue
()
self
.
threads
=
[
PredictorWorkerThread
(
self
.
input_queue
,
f
)
for
f
in
trainer
.
get_predict_funcs
(
input_names
,
output_names
,
nr_thread
)]
def
run
(
self
):
for
t
in
self
.
threads
:
t
.
start
()
#def run(self):
#pass
def
put_task
(
self
,
inputs
,
callback
=
None
):
""" return a Future of output."""
f
=
Future
()
self
.
input_queue
.
put
((
inputs
,
f
))
if
callback
is
not
None
:
f
.
add_done_callback
(
callback
)
return
f
tensorpack/train/base.py
View file @
80722088
...
...
@@ -52,9 +52,17 @@ class Trainer(object):
@
abstractmethod
def
get_predict_func
(
self
,
input_names
,
output_names
):
""" return a predict function"""
""" return a predict
or
function"""
pass
def
get_predict_funcs
(
self
,
input_names
,
output_names
,
n
):
""" return n predictor functions.
Can be overwritten by subclasses to exploit more
parallelism among funcs.
"""
return
[
self
.
get_predict_func
(
input_name
,
output_names
)
for
k
in
range
(
n
)]
def
trigger_epoch
(
self
):
self
.
_trigger_epoch
()
self
.
config
.
callbacks
.
trigger_epoch
()
...
...
tensorpack/train/trainer.py
View file @
80722088
...
...
@@ -117,7 +117,7 @@ class QueueInputTrainer(Trainer):
self
.
async
=
async
if
self
.
async
:
assert
self
.
config
.
nr_tower
>
1
self
.
_
dequed_inputs
=
[]
self
.
dequed_inputs
=
[]
@
staticmethod
def
_average_grads
(
tower_grads
):
...
...
@@ -140,7 +140,7 @@ class QueueInputTrainer(Trainer):
assert
len
(
ret
)
==
len
(
self
.
input_vars
)
for
qv
,
v
in
zip
(
ret
,
self
.
input_vars
):
qv
.
set_shape
(
v
.
get_shape
())
self
.
_
dequed_inputs
.
append
(
ret
)
self
.
dequed_inputs
.
append
(
ret
)
return
ret
def
_single_tower_grad
(
self
):
...
...
@@ -241,27 +241,31 @@ class QueueInputTrainer(Trainer):
summary_str
=
self
.
summary_op
.
eval
()
self
.
_process_summary
(
summary_str
)
def
get_predict_func
(
self
,
input_names
,
output_names
):
def
get_predict_func
(
self
,
input_names
,
output_names
,
tower
=
0
):
"""
:param tower: return the kth predict_func
"""
tower
=
tower
%
self
.
config
.
nr_tower
logger
.
info
(
"Prepare a predictor function for tower{} ..."
.
format
(
tower
))
raw_input_vars
=
get_vars_by_names
(
input_names
)
input_var_idxs
=
[
self
.
input_vars
.
index
(
v
)
for
v
in
raw_input_vars
]
if
self
.
config
.
nr_tower
==
1
:
dequed
=
self
.
_dequed_inputs
[
0
]
input_vars
=
[
dequed
[
k
]
for
k
in
input_var_idxs
]
output_vars
=
get_vars_by_names
(
output_names
)
else
:
# TODO naive impl: use the first tower only
dequed
=
self
.
_dequed_inputs
[
0
]
dequed
=
self
.
dequed_inputs
[
tower
]
input_vars
=
[
dequed
[
k
]
for
k
in
input_var_idxs
]
output_names
=
[
'tower0/'
+
n
for
n
in
output_names
]
output_vars
=
get_vars_by_names
(
output_names
)
if
self
.
config
.
nr_tower
>
1
:
output_names
=
[
'tower{}/'
.
format
(
tower
)
+
n
for
n
in
output_names
]
output_vars
=
get_vars_by_names
(
output_names
)
def
func
(
inputs
):
assert
len
(
inputs
)
==
len
(
input_vars
)
feed
=
dict
(
zip
(
input_vars
,
inputs
))
return
self
.
sess
.
run
(
output_vars
,
feed_dict
=
feed
)
return
func
def
get_predict_funcs
(
self
,
input_names
,
output_names
,
n
):
return
[
self
.
get_predict_func
(
input_name
,
output_names
,
k
)
for
k
in
range
(
n
)]
def
start_train
(
config
):
tr
=
QueueInputTrainer
(
config
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment