Commit 4000f5d5 authored by Yuxin Wu's avatar Yuxin Wu

fix undefined names in multithreadasyncpredictor

parent 32ea8a29
...@@ -83,13 +83,12 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker): ...@@ -83,13 +83,12 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
self.outqueue.put((tid, self.func(dp))) self.outqueue.put((tid, self.func(dp)))
class PredictorWorkerThread(threading.Thread): class PredictorWorkerThread(threading.Thread):
def __init__(self, queue, pred_func, id, nr_input_var, batch_size=5): def __init__(self, queue, pred_func, id, batch_size=5):
super(PredictorWorkerThread, self).__init__() super(PredictorWorkerThread, self).__init__()
self.queue = queue self.queue = queue
self.func = pred_func self.func = pred_func
self.daemon = True self.daemon = True
self.batch_size = batch_size self.batch_size = batch_size
self.nr_input_var = nr_input_var
self.id = id self.id = id
def run(self): def run(self):
...@@ -109,16 +108,17 @@ class PredictorWorkerThread(threading.Thread): ...@@ -109,16 +108,17 @@ class PredictorWorkerThread(threading.Thread):
def fetch_batch(self): def fetch_batch(self):
""" Fetch a batch of data without waiting""" """ Fetch a batch of data without waiting"""
batched, futures = [[] for _ in range(self.nr_input_var)], []
inp, f = self.queue.get() inp, f = self.queue.get()
for k in range(self.nr_input_var): nr_input_var = len(inp)
batched, futures = [[] for _ in range(nr_input_var)], []
for k in range(nr_input_var):
batched[k].append(inp[k]) batched[k].append(inp[k])
futures.append(f) futures.append(f)
cnt = 1 cnt = 1
while cnt < self.batch_size: while cnt < self.batch_size:
try: try:
inp, f = self.queue.get_nowait() inp, f = self.queue.get_nowait()
for k in range(self.nr_input_var): for k in range(nr_input_var):
batched[k].append(inp[k]) batched[k].append(inp[k])
futures.append(f) futures.append(f)
except queue.Empty: except queue.Empty:
...@@ -133,11 +133,10 @@ class MultiThreadAsyncPredictor(object): ...@@ -133,11 +133,10 @@ class MultiThreadAsyncPredictor(object):
""" """
def __init__(self, funcs, batch_size=5): def __init__(self, funcs, batch_size=5):
""" :param funcs: a list of predict func""" """ :param funcs: a list of predict func"""
self.input_queue = queue.Queue(maxsize=nr_thread*10) self.input_queue = queue.Queue(maxsize=len(funcs)*10)
self.threads = [ self.threads = [
PredictorWorkerThread( PredictorWorkerThread(
self.input_queue, f, id, self.input_queue, f, id, batch_size=batch_size)
len(input_names), batch_size=batch_size)
for id, f in enumerate(funcs)] for id, f in enumerate(funcs)]
# TODO XXX set logging here to avoid affecting TF logging # TODO XXX set logging here to avoid affecting TF logging
......
...@@ -113,9 +113,10 @@ class QueueInputTrainer(Trainer): ...@@ -113,9 +113,10 @@ class QueueInputTrainer(Trainer):
""" """
super(QueueInputTrainer, self).__init__(config) super(QueueInputTrainer, self).__init__(config)
self.input_vars = self.model.get_input_vars() self.input_vars = self.model.get_input_vars()
# use a smaller queue size for now, to avoid https://github.com/tensorflow/tensorflow/issues/2942
if input_queue is None: if input_queue is None:
self.input_queue = tf.FIFOQueue( self.input_queue = tf.FIFOQueue(
100, [x.dtype for x in self.input_vars], name='input_queue') 30, [x.dtype for x in self.input_vars], name='input_queue')
else: else:
self.input_queue = input_queue self.input_queue = input_queue
if predict_tower is None: if predict_tower is None:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment