Commit 4000f5d5 authored by Yuxin Wu's avatar Yuxin Wu

fix undefined names in multithreadasyncpredictor

parent 32ea8a29
......@@ -83,13 +83,12 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
self.outqueue.put((tid, self.func(dp)))
class PredictorWorkerThread(threading.Thread):
def __init__(self, queue, pred_func, id, nr_input_var, batch_size=5):
def __init__(self, queue, pred_func, id, batch_size=5):
super(PredictorWorkerThread, self).__init__()
self.queue = queue
self.func = pred_func
self.daemon = True
self.batch_size = batch_size
self.nr_input_var = nr_input_var
self.id = id
def run(self):
......@@ -109,16 +108,17 @@ class PredictorWorkerThread(threading.Thread):
def fetch_batch(self):
""" Fetch a batch of data without waiting"""
batched, futures = [[] for _ in range(self.nr_input_var)], []
inp, f = self.queue.get()
for k in range(self.nr_input_var):
nr_input_var = len(inp)
batched, futures = [[] for _ in range(nr_input_var)], []
for k in range(nr_input_var):
batched[k].append(inp[k])
futures.append(f)
cnt = 1
while cnt < self.batch_size:
try:
inp, f = self.queue.get_nowait()
for k in range(self.nr_input_var):
for k in range(nr_input_var):
batched[k].append(inp[k])
futures.append(f)
except queue.Empty:
......@@ -133,11 +133,10 @@ class MultiThreadAsyncPredictor(object):
"""
def __init__(self, funcs, batch_size=5):
""" :param funcs: a list of predict func"""
self.input_queue = queue.Queue(maxsize=nr_thread*10)
self.input_queue = queue.Queue(maxsize=len(funcs)*10)
self.threads = [
PredictorWorkerThread(
self.input_queue, f, id,
len(input_names), batch_size=batch_size)
self.input_queue, f, id, batch_size=batch_size)
for id, f in enumerate(funcs)]
# TODO XXX set logging here to avoid affecting TF logging
......
......@@ -113,9 +113,10 @@ class QueueInputTrainer(Trainer):
"""
super(QueueInputTrainer, self).__init__(config)
self.input_vars = self.model.get_input_vars()
# use a smaller queue size for now, to avoid https://github.com/tensorflow/tensorflow/issues/2942
if input_queue is None:
self.input_queue = tf.FIFOQueue(
100, [x.dtype for x in self.input_vars], name='input_queue')
30, [x.dtype for x in self.input_vars], name='input_queue')
else:
self.input_queue = input_queue
if predict_tower is None:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment