Commit 75e8d9fe authored by Yuxin Wu

more generic parallel predict worker

parent 6e562112
...@@ -17,7 +17,8 @@ from .tfutils.modelutils import describe_model ...@@ -17,7 +17,8 @@ from .tfutils.modelutils import describe_model
from .dataflow import DataFlow, BatchData from .dataflow import DataFlow, BatchData
from .dataflow.dftools import dataflow_to_process_queue from .dataflow.dftools import dataflow_to_process_queue
__all__ = ['PredictConfig', 'DatasetPredictor', 'get_predict_func'] __all__ = ['PredictConfig', 'DatasetPredictor', 'get_predict_func',
'ParallelPredictWorker']
PredictResult = namedtuple('PredictResult', ['input', 'output']) PredictResult = namedtuple('PredictResult', ['input', 'output'])
...@@ -97,24 +98,19 @@ def get_predict_func(config): ...@@ -97,24 +98,19 @@ def get_predict_func(config):
return sess.run(output_vars, feed_dict=feed) return sess.run(output_vars, feed_dict=feed)
return run_input return run_input
class PredictWorker(multiprocessing.Process): class ParallelPredictWorker(multiprocessing.Process):
""" A worker process to run predictor on one GPU """ def __init__(self, idx, gpuid, config):
def __init__(self, idx, gpuid, inqueue, outqueue, config):
""" """
:param idx: index of the worker. the 0th worker will print log. :param idx: index of the worker. the 0th worker will print log.
:param gpuid: id of the GPU to be used. set to -1 to use CPU. :param gpuid: id of the GPU to be used. set to -1 to use CPU.
:param inqueue: input queue to get data point
:param outqueue: output queue put result
:param config: a `PredictConfig` :param config: a `PredictConfig`
""" """
super(PredictWorker, self).__init__() super(ParallelPredictWorker, self).__init__()
self.idx = idx self.idx = idx
self.gpuid = gpuid self.gpuid = gpuid
self.inqueue = inqueue
self.outqueue = outqueue
self.config = config self.config = config
def run(self): def _init_runtime(self):
if self.gpuid >= 0: if self.gpuid >= 0:
logger.info("Worker {} uses GPU {}".format(self.idx, self.gpuid)) logger.info("Worker {} uses GPU {}".format(self.idx, self.gpuid))
os.environ['CUDA_VISIBLE_DEVICES'] = self.gpuid os.environ['CUDA_VISIBLE_DEVICES'] = self.gpuid
...@@ -128,6 +124,24 @@ class PredictWorker(multiprocessing.Process): ...@@ -128,6 +124,24 @@ class PredictWorker(multiprocessing.Process):
self.func = get_predict_func(self.config) self.func = get_predict_func(self.config)
if self.idx == 0: if self.idx == 0:
describe_model() describe_model()
class QueuePredictWorker(ParallelPredictWorker):
""" A worker process to run predictor on one GPU """
def __init__(self, idx, gpuid, inqueue, outqueue, config):
"""
:param idx: index of the worker. the 0th worker will print log.
:param gpuid: id of the GPU to be used. set to -1 to use CPU.
:param inqueue: input queue to get data point
:param outqueue: output queue put result
:param config: a `PredictConfig`
"""
super(QueuePredictWorker, self).__init__(idx, gpuid, config)
self.inqueue = inqueue
self.outqueue = outqueue
def run(self):
self._init_runtime()
while True: while True:
tid, dp = self.inqueue.get() tid, dp = self.inqueue.get()
if tid == DIE: if tid == DIE:
...@@ -156,7 +170,7 @@ class DatasetPredictor(object): ...@@ -156,7 +170,7 @@ class DatasetPredictor(object):
gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',') gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
except KeyError: except KeyError:
gpus = list(range(self.nr_gpu)) gpus = list(range(self.nr_gpu))
self.workers = [PredictWorker(i, gpus[i], self.inqueue, self.outqueue, config) self.workers = [QueuePredictWorker(i, gpus[i], self.inqueue, self.outqueue, config)
for i in range(self.nr_gpu)] for i in range(self.nr_gpu)]
self.result_queue = OrderedResultGatherProc(self.outqueue) self.result_queue = OrderedResultGatherProc(self.outqueue)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment