Commit efbf256e authored by Yuxin Wu's avatar Yuxin Wu

fix multiprocess datasetpredictor

parent 95b6437a
...@@ -59,8 +59,8 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker): ...@@ -59,8 +59,8 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
super(MultiProcessQueuePredictWorker, self).__init__(idx, config) super(MultiProcessQueuePredictWorker, self).__init__(idx, config)
self.inqueue = inqueue self.inqueue = inqueue
self.outqueue = outqueue self.outqueue = outqueue
assert isinstance(self.inqueue, multiprocessing.Queue) assert isinstance(self.inqueue, multiprocessing.queues.Queue)
assert isinstance(self.outqueue, multiprocessing.Queue) assert isinstance(self.outqueue, multiprocessing.queues.Queue)
def run(self): def run(self):
self._init_runtime() self._init_runtime()
......
...@@ -12,6 +12,8 @@ import os ...@@ -12,6 +12,8 @@ import os
from ..dataflow import DataFlow, BatchData from ..dataflow import DataFlow, BatchData
from ..dataflow.dftools import dataflow_to_process_queue from ..dataflow.dftools import dataflow_to_process_queue
from ..utils.concurrency import ensure_proc_terminate, OrderedResultGatherProc, DIE from ..utils.concurrency import ensure_proc_terminate, OrderedResultGatherProc, DIE
from ..utils import logger
from ..utils.gpu import change_gpu
from .concurrency import MultiProcessQueuePredictWorker from .concurrency import MultiProcessQueuePredictWorker
from .common import PredictConfig from .common import PredictConfig
...@@ -89,9 +91,9 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase): ...@@ -89,9 +91,9 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
# TODO number of GPUs not checked # TODO number of GPUs not checked
gpus = list(range(self.nr_proc)) gpus = list(range(self.nr_proc))
else: else:
gpus = [''] * self.nr_proc gpus = ['-1'] * self.nr_proc
self.workers = [MultiProcessQueuePredictWorker( self.workers = [MultiProcessQueuePredictWorker(
i, gpus[i], self.inqueue, self.outqueue, self.config) i, self.inqueue, self.outqueue, self.config)
for i in range(self.nr_proc)] for i in range(self.nr_proc)]
self.result_queue = OrderedResultGatherProc( self.result_queue = OrderedResultGatherProc(
self.outqueue, nr_producer=self.nr_proc) self.outqueue, nr_producer=self.nr_proc)
......
...@@ -9,7 +9,10 @@ from .utils import change_env ...@@ -9,7 +9,10 @@ from .utils import change_env
__all__ = ['change_gpu', 'get_nr_gpu', 'get_gpus'] __all__ = ['change_gpu', 'get_nr_gpu', 'get_gpus']
def change_gpu(val): def change_gpu(val):
return change_env('CUDA_VISIBLE_DEVICES', str(val)) val = str(val)
if val == '-1':
val = ''
return change_env('CUDA_VISIBLE_DEVICES', val)
def get_nr_gpu(): def get_nr_gpu():
env = os.environ.get('CUDA_VISIBLE_DEVICES', None) env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment