Commit 75e8d9fe authored by Yuxin Wu

more generic parallel predict worker

parent 6e562112
@@ -17,7 +17,8 @@ from .tfutils.modelutils import describe_model
from .dataflow import DataFlow, BatchData
from .dataflow.dftools import dataflow_to_process_queue
__all__ = ['PredictConfig', 'DatasetPredictor', 'get_predict_func']
__all__ = ['PredictConfig', 'DatasetPredictor', 'get_predict_func',
'ParallelPredictWorker']
PredictResult = namedtuple('PredictResult', ['input', 'output'])
@@ -97,24 +98,19 @@ def get_predict_func(config):
return sess.run(output_vars, feed_dict=feed)
return run_input
class PredictWorker(multiprocessing.Process):
""" A worker process to run predictor on one GPU """
def __init__(self, idx, gpuid, inqueue, outqueue, config):
class ParallelPredictWorker(multiprocessing.Process):
def __init__(self, idx, gpuid, config):
"""
:param idx: index of the worker. The 0th worker will print logs.
:param gpuid: id of the GPU to use. Set to -1 to run on CPU.
:param inqueue: input queue to get data points from.
:param outqueue: output queue to put results into.
:param config: a `PredictConfig`
"""
super(PredictWorker, self).__init__()
super(ParallelPredictWorker, self).__init__()
self.idx = idx
self.gpuid = gpuid
self.inqueue = inqueue
self.outqueue = outqueue
self.config = config
def run(self):
def _init_runtime(self):
if self.gpuid >= 0:
logger.info("Worker {} uses GPU {}".format(self.idx, self.gpuid))
os.environ['CUDA_VISIBLE_DEVICES'] = self.gpuid
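
The reason this lives in _init_runtime (called from the child's run()) rather than in the parent process is that CUDA_VISIBLE_DEVICES only takes effect if it is set before TensorFlow initializes CUDA in that process; setting it per worker pins each process to a single device. A minimal standalone sketch of that pattern (the class and names below are illustrative, not part of this commit):

import multiprocessing
import os

class OneGPUWorker(multiprocessing.Process):
    # Hypothetical example: pin one GPU per process before touching TensorFlow.
    def __init__(self, idx, gpuid):
        super(OneGPUWorker, self).__init__()
        self.idx = idx
        self.gpuid = gpuid

    def run(self):
        # Must happen before any TF session is created in this process,
        # otherwise TF would see (and grab memory on) every visible GPU.
        os.environ['CUDA_VISIBLE_DEVICES'] = str(self.gpuid)
        import tensorflow as tf
        with tf.Session() as sess:
            pass  # build the graph and run predictions here
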
@@ -128,6 +124,24 @@ class PredictWorker(multiprocessing.Process):
self.func = get_predict_func(self.config)
if self.idx == 0:
describe_model()
class QueuePredictWorker(ParallelPredictWorker):
""" A worker process to run predictor on one GPU """
def __init__(self, idx, gpuid, inqueue, outqueue, config):
"""
:param idx: index of the worker. The 0th worker will print logs.
:param gpuid: id of the GPU to use. Set to -1 to run on CPU.
:param inqueue: input queue to get data points from.
:param outqueue: output queue to put results into.
:param config: a `PredictConfig`
"""
super(QueuePredictWorker, self).__init__(idx, gpuid, config)
self.inqueue = inqueue
self.outqueue = outqueue
def run(self):
self._init_runtime()
while True:
tid, dp = self.inqueue.get()
if tid == DIE:
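
The value of this split is that ParallelPredictWorker carries only the per-GPU setup (_init_runtime pins the device and builds self.func via get_predict_func), while the queue protocol stays in QueuePredictWorker, so other transports can subclass the base directly. A hypothetical subclass, not part of this commit, that feeds the predictor straight from a DataFlow (assuming, as the run_input closure above suggests, that the predict function takes a datapoint):

class DataFlowPredictWorker(ParallelPredictWorker):
    # Hypothetical example, not in this commit: pull datapoints from a
    # DataFlow instead of an input queue; only results go through a queue.
    def __init__(self, idx, gpuid, config, dataflow, outqueue):
        super(DataFlowPredictWorker, self).__init__(idx, gpuid, config)
        self.dataflow = dataflow
        self.outqueue = outqueue

    def run(self):
        self._init_runtime()          # same GPU pinning + self.func as QueuePredictWorker
        self.dataflow.reset_state()   # re-seed the DataFlow inside the child process
        for dp in self.dataflow.get_data():
            self.outqueue.put(self.func(dp))
        self.outqueue.put(None)       # simple end-of-results sentinel for this sketch
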
@@ -156,7 +170,7 @@ class DatasetPredictor(object):
gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
except KeyError:
gpus = list(range(self.nr_gpu))
self.workers = [PredictWorker(i, gpus[i], self.inqueue, self.outqueue, config)
self.workers = [QueuePredictWorker(i, gpus[i], self.inqueue, self.outqueue, config)
for i in range(self.nr_gpu)]
self.result_queue = OrderedResultGatherProc(self.outqueue)
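
For orientation, here is a rough, self-contained sketch of the protocol these workers speak, pieced together from what is visible in this diff: the producer puts (task_id, datapoint) tuples on the input queue plus one DIE sentinel per worker, and each worker is assumed to echo (task_id, output) tuples on the output queue, which is what the ordered gathering by OrderedResultGatherProc suggests. Everything except QueuePredictWorker, DIE, gpus and config is illustrative; the import of dataflow_to_process_queue at the top of the file indicates the real DatasetPredictor fills the input queue from a DataFlow rather than by hand.

inqueue = multiprocessing.Queue()
outqueue = multiprocessing.Queue()
workers = [QueuePredictWorker(i, gpus[i], inqueue, outqueue, config)
           for i in range(len(gpus))]
for w in workers:
    w.start()

for tid, dp in enumerate(datapoints):   # datapoints: a list of input datapoints
    inqueue.put((tid, dp))
for _ in workers:
    inqueue.put((DIE, None))            # one termination sentinel per worker

results = {}
while len(results) < len(datapoints):
    tid, output = outqueue.get()
    if tid == DIE:                      # a worker may forward the sentinel when it exits
        continue
    results[tid] = output               # results arrive out of order; reorder by tid
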