Commit d6f1b6ee authored by Yuxin Wu's avatar Yuxin Wu

bug fix in prefecthongpus

parent be2d2001
......@@ -13,7 +13,8 @@ import os
from .base import ProxyDataFlow
from ..utils.concurrency import *
from ..utils.serialize import loads, dumps
from ..utils import logger, change_env
from ..utils import logger
from ..utils.gpu import change_gpu
__all__ = ['PrefetchData', 'BlockParallel']
try:
......@@ -173,8 +174,8 @@ class PrefetchDataZMQ(ProxyDataFlow):
class PrefetchOnGPUs(PrefetchDataZMQ):
""" Prefetch with each process having a specific CUDA_VISIBLE_DEVICES"""
def __init__(self, ds, gpus, pipedir=None):
super(PrefetchOnGPUs, self).__init__(ds, len(gpus), pipedir)
self.gpus = gpus
super(PrefetchOnGPUs, self).__init__(ds, len(gpus), pipedir)
def start_processes(self):
with mask_sigint():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment