Commit 5667a220 authored by Patrick Wieschollek, committed by Yuxin Wu

add NVML binding (#642)

* add NVML binding

* fix flake8

* separate gpu and nvml

* remove stuff

* switch to context

* Clean-up and rename

* fix lint
parent 2ce7ec1c
@@ -6,9 +6,9 @@
 import os
 from .utils import change_env
 from . import logger
+from .nvml import NVMLContext
 from .concurrency import subproc_call
 __all__ = ['change_gpu', 'get_nr_gpu']
@@ -36,9 +36,15 @@ def get_nr_gpu():
         output = output.decode('utf-8')
         return len(output.strip().split('\n'))
     else:
-        # Note this will initialize all GPUs and therefore has side effect
-        # https://github.com/tensorflow/tensorflow/issues/8136
-        logger.info("Loading local devices by TensorFlow ...")
-        from tensorflow.python.client import device_lib
-        local_device_protos = device_lib.list_local_devices()
-        return len([x.name for x in local_device_protos if x.device_type == 'GPU'])
+        try:
+            # Use NVML to query device properties
+            with NVMLContext() as ctx:
+                return ctx.num_devices()
+        except Exception:
+            # Fallback
+            # Note this will initialize all GPUs and therefore has side effect
+            # https://github.com/tensorflow/tensorflow/issues/8136
+            logger.info("Loading local devices by TensorFlow ...")
+            from tensorflow.python.client import device_lib
+            local_device_protos = device_lib.list_local_devices()
+            return len([x.name for x in local_device_protos if x.device_type == 'GPU'])
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: nvml.py
from ctypes import (byref, c_uint, c_ulonglong,
CDLL, POINTER, Structure)
import threading
__all__ = ['NVMLContext']
# Status code for a missing function; matches the "13" entry in the table below.
NVML_ERROR_FUNCTION_NOT_FOUND = 13

# Symbolic names of NVML status codes, keyed by the decimal string of the code.
NvmlErrorCodes = {"0": "NVML_SUCCESS",
                  "1": "NVML_ERROR_UNINITIALIZED",
                  "2": "NVML_ERROR_INVALID_ARGUMENT",
                  "3": "NVML_ERROR_NOT_SUPPORTED",
                  "4": "NVML_ERROR_NO_PERMISSION",
                  "5": "NVML_ERROR_ALREADY_INITIALIZED",
                  "6": "NVML_ERROR_NOT_FOUND",
                  "7": "NVML_ERROR_INSUFFICIENT_SIZE",
                  "8": "NVML_ERROR_INSUFFICIENT_POWER",
                  "9": "NVML_ERROR_DRIVER_NOT_LOADED",
                  "10": "NVML_ERROR_TIMEOUT",
                  "11": "NVML_ERROR_IRQ_ISSUE",
                  "12": "NVML_ERROR_LIBRARY_NOT_FOUND",
                  "13": "NVML_ERROR_FUNCTION_NOT_FOUND",
                  "14": "NVML_ERROR_CORRUPTED_INFOROM",
                  "15": "NVML_ERROR_GPU_IS_LOST",
                  "16": "NVML_ERROR_RESET_REQUIRED",
                  "17": "NVML_ERROR_OPERATING_SYSTEM",
                  "18": "NVML_ERROR_LIB_RM_VERSION_MISMATCH",
                  "999": "NVML_ERROR_UNKNOWN"}


class NvmlException(Exception):
    """Exception carrying an NVML status code.

    Args:
        error_code: the status code; stored on ``self.error_code`` and
            rendered as its symbolic name by ``str()``.
    """

    def __init__(self, error_code):
        super(NvmlException, self).__init__(error_code)
        self.error_code = error_code

    def __str__(self):
        """Return the symbolic name of the code, or a fallback message.

        A code that is missing from ``NvmlErrorCodes`` must not raise
        ``KeyError`` here: that would hide the original error while it is
        being formatted for a traceback or log line.
        """
        code = str(self.error_code)
        return NvmlErrorCodes.get(
            code, "NVML_ERROR_UNKNOWN (unrecognized code %s)" % code)
def _check_return(ret):
    """Pass a zero status code through, raise for anything else.

    Args:
        ret: status code returned by an NVML call.

    Returns:
        ``ret`` unchanged when it equals ``0``.

    Raises:
        NvmlException: carrying ``ret`` when it is non-zero.
    """
    if ret == 0:
        return ret
    raise NvmlException(ret)
class NVML(object):
    """
    Loader for libnvidia-ml.so

    ``load()`` opens the shared library once and resolves a fixed list of
    entry points; ``get_function()`` hands out the resolved pointers by name.
    """

    _nvmlLib = None
    _lib_lock = threading.Lock()

    def load(self):
        """Open ``libnvidia-ml.so.1`` and resolve the entry points, once.

        Calls made after a successful load are no-ops. The work is done while
        holding ``_lib_lock``.

        Raises:
            NvmlException: if one of the entry points is missing from the
                library. The loader is reset in that case so that a later
                call retries instead of leaving a half-initialized object.
        """
        with self._lib_lock:
            if self._nvmlLib is not None:
                return

            self._nvmlLib = CDLL("libnvidia-ml.so.1")

            function_pointers = ["nvmlDeviceGetName", "nvmlDeviceGetUUID", "nvmlDeviceGetMemoryInfo",
                                 "nvmlDeviceGetUtilizationRates", "nvmlInit_v2", "nvmlShutdown",
                                 "nvmlDeviceGetCount_v2", "nvmlDeviceGetHandleByIndex_v2"]

            try:
                self.func_ptr = {n: self._function_pointer(n) for n in function_pointers}
            except NvmlException:
                # Without this reset the library handle would stay set while
                # func_ptr was never created, and every later load() would
                # be skipped.
                self._nvmlLib = None
                raise

    def _function_pointer(self, name):
        """Look up ``name`` on the loaded library.

        Raises:
            NvmlException: with NVML_ERROR_FUNCTION_NOT_FOUND when the
                library has no such attribute.
        """
        try:
            return getattr(self._nvmlLib, name)
        except AttributeError:
            raise NvmlException(NVML_ERROR_FUNCTION_NOT_FOUND)

    def get_function(self, name):
        """Return the resolved function pointer registered under ``name``.

        Raises:
            NvmlException: with NVML_ERROR_FUNCTION_NOT_FOUND when ``name``
                was not resolved by ``load()``, or when ``load()`` has not
                populated the table yet. Raising here replaces a silent
                ``None`` that callers would then try to call.
        """
        try:
            return self.func_ptr[name]
        except (AttributeError, KeyError):
            raise NvmlException(NVML_ERROR_FUNCTION_NOT_FOUND)


# Module-wide loader instance.
_NVML = NVML()
class NvidiaDevice(object):
    """A single GPU device, addressed through its NVML handle."""

    def __init__(self, hnd):
        super(NvidiaDevice, self).__init__()
        # Handle passed as first argument to the per-device NVML calls below.
        self.hnd = hnd

    def memory(self):
        """Query the memory counters of this device.

        Example:

            >>> print(ctx.device(0).memory())
            {'total': 4238016512L, 'used': 434831360L, 'free': 3803185152L}

        Returns:
            dict: the keys ``'total'``, ``'free'`` and ``'used'``, each
            holding an amount of memory in bytes.

        Raises:
            NvmlException: if the NVML call returns a non-zero status.
        """
        class GpuMemoryInfo(Structure):
            _fields_ = [
                ('total', c_ulonglong),
                ('free', c_ulonglong),
                ('used', c_ulonglong),
            ]

        info = GpuMemoryInfo()
        query = _NVML.get_function("nvmlDeviceGetMemoryInfo")
        # The call fills `info` through the reference passed in.
        _check_return(query(self.hnd, byref(info)))
        return {key: getattr(info, key) for key in ('total', 'free', 'used')}

    def utilization(self):
        """Query how busy this device was over the past second.

        Details:
            ``'gpu'``: percent of time over the past second during which one
            or more kernels was executing on the GPU.
            ``'memory'``: percent of time over the past second during which
            global (device) memory was being read or written.

        Example:

            >>> print(ctx.device(0).utilization())
            {'gpu': 4L, 'memory': 6L}

        Returns:
            dict: the keys ``'gpu'`` and ``'memory'``.

        Raises:
            NvmlException: if the NVML call returns a non-zero status.
        """
        class GpuUtilizationInfo(Structure):
            _fields_ = [
                ('gpu', c_uint),
                ('memory', c_uint),
            ]

        rates = GpuUtilizationInfo()
        query = _NVML.get_function("nvmlDeviceGetUtilizationRates")
        _check_return(query(self.hnd, byref(rates)))
        return {key: getattr(rates, key) for key in ('gpu', 'memory')}
class NVMLContext(object):
    """Context manager that brackets NVML queries with init and shutdown.

    Example:

        with NVMLContext() as ctx:
            num_gpus = ctx.num_devices()
            for device in ctx.devices():
                print(device.memory())
                print(device.utilization())
    """

    def __enter__(self):
        """Load the library, call ``nvmlInit_v2`` and return this context."""
        _NVML.load()
        initialize = _NVML.get_function("nvmlInit_v2")
        _check_return(initialize())
        return self

    def __exit__(self, type, value, tb):
        """Call ``nvmlShutdown``; raises NvmlException if it fails."""
        shutdown = _NVML.get_function("nvmlShutdown")
        _check_return(shutdown())

    def num_devices(self):
        """Return the device count reported by ``nvmlDeviceGetCount_v2``."""
        count = c_uint()
        get_count = _NVML.get_function("nvmlDeviceGetCount_v2")
        _check_return(get_count(byref(count)))
        return count.value

    def devices(self):
        """
        Returns:
            [NvidiaDevice]: one entry per index below ``num_devices()``
        """
        total = self.num_devices()
        return [self.device(index) for index in range(total)]

    def device(self, idx):
        """Fetch the device registered under a given index.

        Args:
            idx: index of device

        Returns:
            NvidiaDevice: wrapper around the handle obtained for ``idx``

        Raises:
            NvmlException: if the handle lookup returns a non-zero status.
        """
        class GpuDevice(Structure):
            pass

        index = c_uint(idx)
        # Null pointer that the NVML call fills in with the device handle.
        handle = POINTER(GpuDevice)()
        lookup = _NVML.get_function("nvmlDeviceGetHandleByIndex_v2")
        _check_return(lookup(index, byref(handle)))
        return NvidiaDevice(handle)
if __name__ == '__main__':
    # Two consecutive passes, each in its own context -- presumably to check
    # that NVML can be initialized again after a shutdown.
    for _ in range(2):
        with NVMLContext() as ctx:
            print(ctx.devices())
            print(ctx.devices()[0].utilization())
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment