Commit e7aaaf13 authored by Yuxin Wu's avatar Yuxin Wu

GPU Util Tracker use CUDA_VISIBLE_DEVICES

parent a23a92d1
...@@ -35,16 +35,22 @@ class SendStat(Callback): ...@@ -35,16 +35,22 @@ class SendStat(Callback):
class GPUUtilizationTracker(Callback): class GPUUtilizationTracker(Callback):
""" Summarize the average GPU utilization within an epoch""" """ Summarize the average GPU utilization within an epoch"""
def __init__(self, devices): def __init__(self, devices=None):
""" """
Args: Args:
devices (list[int]): physical GPU ids devices (list[int]): physical GPU ids. If None, will use CUDA_VISIBLE_DEVICES
""" """
self._devices = list(map(str, devices)) if devices is None:
env = os.environ.get('CUDA_VISIBLE_DEVICES')
assert env is not None, "[GPUUtilizationTracker] Both devices and CUDA_VISIBLE_DEVICES are None!"
self._devices = env.split(',')
else:
self._devices = list(map(str, devices))
assert len(self._devices), "[GPUUtilizationTracker] No GPU device given!"
self._command = "nvidia-smi --query-gpu=utilization.gpu --format=csv,noheader,nounits -i " + \ self._command = "nvidia-smi --query-gpu=utilization.gpu --format=csv,noheader,nounits -i " + \
','.join(self._devices) ','.join(self._devices)
output, ret = subproc_call(self._command) _, ret = subproc_call(self._command)
assert ret == 0, "Cannot fetch GPU utilization!" assert ret == 0, "Cannot fetch GPU utilization!"
def _before_train(self): def _before_train(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment