Commit 4a1af743, authored by Aaron Gokaslan and committed by Yuxin Wu

MultiProcessDatasetPredictor fix for 1 GPU (#939)

Updated MultiProcessDatasetPredictor so that it can run with a single worker process on one GPU (this allows asynchronous dataset processing). It also now checks the actual number of available GPUs when CUDA_VISIBLE_DEVICES is not defined.
parent 06bc1c5d
......@@ -13,7 +13,7 @@ from ..dataflow.dftools import dump_dataflow_to_process_queue
from ..utils.concurrency import ensure_proc_terminate, OrderedResultGatherProc, DIE
from ..utils import logger
from ..utils.utils import get_tqdm
from ..utils.gpu import change_gpu
from ..utils.gpu import change_gpu, get_num_gpu
from .concurrency import MultiProcessQueuePredictWorker
from .config import PredictConfig
......@@ -99,7 +99,7 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
"""
if config.return_input:
logger.warn("Using the option `return_input` in MultiProcessDatasetPredictor might be slow")
assert nr_proc > 1, nr_proc
assert nr_proc >= 1, nr_proc
super(MultiProcessDatasetPredictor, self).__init__(config, dataset)
self.nr_proc = nr_proc
......@@ -111,12 +111,11 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
if use_gpu:
try:
gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
assert len(gpus) >= self.nr_proc, \
except KeyError:
gpus = list(range(get_num_gpu()))
assert len(gpus) >= self.nr_proc, \
"nr_proc={} while only {} gpus available".format(
self.nr_proc, len(gpus))
except KeyError:
# TODO number of GPUs not checked
gpus = list(range(self.nr_proc))
else:
gpus = ['-1'] * self.nr_proc
# worker produces (idx, result) to outqueue
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment