Commit d6f1b6ee authored by Yuxin Wu's avatar Yuxin Wu

bug fix in prefecthongpus

parent be2d2001
...@@ -13,7 +13,8 @@ import os ...@@ -13,7 +13,8 @@ import os
from .base import ProxyDataFlow from .base import ProxyDataFlow
from ..utils.concurrency import * from ..utils.concurrency import *
from ..utils.serialize import loads, dumps from ..utils.serialize import loads, dumps
from ..utils import logger, change_env from ..utils import logger
from ..utils.gpu import change_gpu
__all__ = ['PrefetchData', 'BlockParallel'] __all__ = ['PrefetchData', 'BlockParallel']
try: try:
...@@ -173,8 +174,8 @@ class PrefetchDataZMQ(ProxyDataFlow): ...@@ -173,8 +174,8 @@ class PrefetchDataZMQ(ProxyDataFlow):
class PrefetchOnGPUs(PrefetchDataZMQ): class PrefetchOnGPUs(PrefetchDataZMQ):
""" Prefetch with each process having a specific CUDA_VISIBLE_DEVICES""" """ Prefetch with each process having a specific CUDA_VISIBLE_DEVICES"""
def __init__(self, ds, gpus, pipedir=None): def __init__(self, ds, gpus, pipedir=None):
super(PrefetchOnGPUs, self).__init__(ds, len(gpus), pipedir)
self.gpus = gpus self.gpus = gpus
super(PrefetchOnGPUs, self).__init__(ds, len(gpus), pipedir)
def start_processes(self): def start_processes(self):
with mask_sigint(): with mask_sigint():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment