Commit f4f41711, authored by Aaron Gokaslan and committed by Yuxin Wu

Increase GPU counting robustness (#640)

* Increase GPU counting robustness

Added a native TensorFlow method to count the number of GPUs, used as a fallback when `nvidia-smi` fails.

* docs

* missing import

* fix import
parent 1ce585be
...@@ -5,8 +5,10 @@ ...@@ -5,8 +5,10 @@
import os import os
from .utils import change_env from .utils import change_env
from . import logger
from .concurrency import subproc_call from .concurrency import subproc_call
__all__ = ['change_gpu', 'get_nr_gpu'] __all__ = ['change_gpu', 'get_nr_gpu']
...@@ -30,7 +32,13 @@ def get_nr_gpu(): ...@@ -30,7 +32,13 @@ def get_nr_gpu():
if env is not None: if env is not None:
return len(env.split(',')) return len(env.split(','))
output, code = subproc_call("nvidia-smi -L", timeout=5) output, code = subproc_call("nvidia-smi -L", timeout=5)
if code != 0: if code == 0:
return 0 output = output.decode('utf-8')
output = output.decode('utf-8') return len(output.strip().split('\n'))
return len(output.strip().split('\n')) else:
# Note this will initialize all GPUs and therefore has side effect
# https://github.com/tensorflow/tensorflow/issues/8136
logger.info("Loading local devices by TensorFlow ...")
from tensorflow.python.client import device_lib
local_device_protos = device_lib.list_local_devices()
return len([x.name for x in local_device_protos if x.device_type == 'GPU'])
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment