Shashank Suhas / seminar-breakout / Commits

Commit bc1ba816, authored Jun 02, 2016 by Yuxin Wu
batch input in multithreadpredictor
parent c5da59af
Showing 2 changed files with 41 additions and 6 deletions:

tensorpack/RL/simulator.py (+2, -1)
tensorpack/predict/concurrency.py (+39, -5)
tensorpack/RL/simulator.py

@@ -101,7 +101,8 @@ class SimulatorMaster(threading.Thread):
                 ident, _, msg = self.socket.recv_multipart(zmq.NOBLOCK)
                 break
             except zmq.ZMQError:
-                time.sleep(0.01)
+                #pass
+                time.sleep(0.001)
             #assert _ == ""
             client = self.clients[ident]
             client.protocol_state = 1 - client.protocol_state  # first flip the state
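The change to SimulatorMaster is small: the master polls its ZMQ socket with a non-blocking receive and sleeps between attempts, and the commit shortens that back-off from 10 ms to 1 ms so queued simulator messages are picked up sooner. A minimal sketch of the polling pattern, assuming a pyzmq socket; the recv_nonblocking name and the interval parameter are illustrative, not from the source:

import time
import zmq

def recv_nonblocking(socket, interval=0.001):
    # Spin on a non-blocking receive, sleeping briefly between attempts so
    # the thread does not busy-wait; pyzmq raises zmq.ZMQError (EAGAIN)
    # when no message is queued yet.
    while True:
        try:
            return socket.recv_multipart(zmq.NOBLOCK)
        except zmq.ZMQError:
            time.sleep(interval)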
tensorpack/predict/concurrency.py

@@ -5,8 +5,9 @@
 import multiprocessing, threading
 import tensorflow as tf
 import time
 import six
-from six.moves import queue, range
+from six.moves import queue, range, zip

 from ..utils.concurrency import DIE
@@ -89,10 +90,43 @@ class PredictorWorkerThread(threading.Thread):
         self.id = id

     def run(self):
+        #self.xxx = None
+        def fetch():
+            batched = []
+            futures = []
+            inp, f = self.queue.get()
+            batched.append(inp)
+            futures.append(f)
+            while True:
+                try:
+                    inp, f = self.queue.get_nowait()
+                    batched.append(inp)
+                    futures.append(f)
+                    if len(batched) == 128:
+                        break
+                except queue.Empty:
+                    break
+            return batched, futures
+        #self.xxx = None
         while True:
-            inputs, f = self.queue.get()
-            outputs = self.func(inputs)
-            f.set_result(outputs)
+            # normal input
+            #inputs, f = self.queue.get()
+            #outputs = self.func(inputs)
+            #f.set_result(outputs)
+            batched, futures = fetch()
+            #print "batched size: ", len(batched)
+            outputs = self.func([batched])
+            #if self.xxx is None:
+                #outputs = self.func([batched])
+                #self.xxx = outputs
+            #else:
+                #outputs = [None, None]
+                #outputs[0] = [self.xxx[0][0]] * len(batched)
+                #outputs[1] = [self.xxx[1][0]] * len(batched)
+            for idx, f in enumerate(futures):
+                f.set_result([k[idx] for k in outputs])

 class MultiThreadAsyncPredictor(object):
     """
@@ -117,7 +151,7 @@ class MultiThreadAsyncPredictor(object):
     def put_task(self, inputs, callback=None):
         """ return a Future of output."""
         f = Future()
-        self.input_queue.put((inputs, f))
         if callback is not None:
             f.add_done_callback(callback)
+        self.input_queue.put((inputs, f))
         return f
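put_task is also reordered: the (inputs, f) pair is now enqueued only after the callback has been attached to the future. Since worker threads consume the queue concurrently, enqueueing first allowed the result to be set, and the callback to fire, before put_task had even returned; attaching first makes the sequencing deterministic regardless of how the Future implementation handles late-registered callbacks. A runnable sketch of the patched ordering, using stdlib Queue and Future as stand-ins for the types in the diff:

import threading
from concurrent.futures import Future
from queue import Queue

input_queue = Queue()

def put_task(inputs, callback=None):
    # Mirrors the patched method: attach the callback before a worker
    # thread can possibly see (and complete) the future.
    f = Future()
    if callback is not None:
        f.add_done_callback(callback)
    input_queue.put((inputs, f))
    return f

def worker():
    inputs, f = input_queue.get()
    f.set_result([x * 2 for x in inputs])  # stand-in for the real predictor

threading.Thread(target=worker, daemon=True).start()
fut = put_task([1, 2, 3], callback=lambda f: print("done:", f.result()))
print(fut.result())  # [2, 4, 6]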