Commit bc1ba816 authored by Yuxin Wu's avatar Yuxin Wu

batch input in multithreadpredictor

parent c5da59af
...@@ -101,7 +101,8 @@ class SimulatorMaster(threading.Thread): ...@@ -101,7 +101,8 @@ class SimulatorMaster(threading.Thread):
ident, _, msg = self.socket.recv_multipart(zmq.NOBLOCK) ident, _, msg = self.socket.recv_multipart(zmq.NOBLOCK)
break break
except zmq.ZMQError: except zmq.ZMQError:
time.sleep(0.01) #pass
time.sleep(0.001)
#assert _ == "" #assert _ == ""
client = self.clients[ident] client = self.clients[ident]
client.protocol_state = 1 - client.protocol_state # first flip the state client.protocol_state = 1 - client.protocol_state # first flip the state
......
...@@ -5,8 +5,9 @@ ...@@ -5,8 +5,9 @@
import multiprocessing, threading import multiprocessing, threading
import tensorflow as tf import tensorflow as tf
import time
import six import six
from six.moves import queue, range from six.moves import queue, range, zip
from ..utils.concurrency import DIE from ..utils.concurrency import DIE
...@@ -89,10 +90,43 @@ class PredictorWorkerThread(threading.Thread): ...@@ -89,10 +90,43 @@ class PredictorWorkerThread(threading.Thread):
self.id = id self.id = id
def run(self):
    """Worker loop: repeatedly pull prediction requests off the input
    queue, run them through the predictor function as one batch, and
    deliver each per-request result to its corresponding Future.

    Never returns; meant to run for the lifetime of the worker thread.
    """
    # Upper bound on how many requests are fused into one predictor call.
    MAX_BATCH = 128

    def fetch():
        """Block for one request, then greedily drain the queue (up to
        MAX_BATCH items) to form a batch.  Returns (inputs, futures),
        two parallel lists of equal length >= 1."""
        batched = []
        futures = []
        # Block until at least one request arrives.
        inp, f = self.queue.get()
        batched.append(inp)
        futures.append(f)
        while True:
            try:
                inp, f = self.queue.get_nowait()
                batched.append(inp)
                futures.append(f)
                if len(batched) == MAX_BATCH:
                    break
            except queue.Empty:
                break
        return batched, futures

    while True:
        batched, futures = fetch()
        # self.func takes a list of input components; all requests are
        # stacked into a single batched component here.
        outputs = self.func([batched])
        # outputs is a list of per-output batches; row idx of every
        # output belongs to request idx.
        for idx, f in enumerate(futures):
            f.set_result([k[idx] for k in outputs])
class MultiThreadAsyncPredictor(object): class MultiThreadAsyncPredictor(object):
""" """
...@@ -117,7 +151,7 @@ class MultiThreadAsyncPredictor(object): ...@@ -117,7 +151,7 @@ class MultiThreadAsyncPredictor(object):
def put_task(self, inputs, callback=None):
    """ return a Future of output."""
    fut = Future()
    # Attach the callback before enqueueing, so it cannot miss a
    # completion that happens immediately after the put.
    if callback is not None:
        fut.add_done_callback(callback)
    self.input_queue.put((inputs, fut))
    return fut
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment