Commit f831a46e authored by Yuxin Wu

Do not make unnecessary calls to NVML (#1134)

parent 6f02650d
...@@ -67,7 +67,7 @@ class GPUUtilizationTracker(Callback): ...@@ -67,7 +67,7 @@ class GPUUtilizationTracker(Callback):
self._evt.set() self._evt.set()
def _after_epoch(self): def _after_epoch(self):
while self._evt.is_set(): # unlikely while self._evt.is_set(): # unlikely, unless the epoch is extremely fast
pass pass
self._evt.set() self._evt.set()
...@@ -87,6 +87,8 @@ class GPUUtilizationTracker(Callback): ...@@ -87,6 +87,8 @@ class GPUUtilizationTracker(Callback):
self._proc.terminate() self._proc.terminate()
def worker(self, evt, rst_queue, stop_evt): def worker(self, evt, rst_queue, stop_evt):
with NVMLContext() as ctx:
devices = [ctx.device(i) for i in self._devices]
while True: while True:
try: try:
evt.wait() # start epoch evt.wait() # start epoch
...@@ -96,11 +98,10 @@ class GPUUtilizationTracker(Callback): ...@@ -96,11 +98,10 @@ class GPUUtilizationTracker(Callback):
stats = np.zeros((len(self._devices),), dtype='f4') stats = np.zeros((len(self._devices),), dtype='f4')
cnt = 0 cnt = 0
with NVMLContext() as ctx:
while True: while True:
time.sleep(1) time.sleep(1)
data = [ctx.device(i).utilization()['gpu'] for i in self._devices] data = [d.utilization()['gpu'] for d in devices]
data = list(map(float, data)) data = list(map(float, data))
stats += data stats += data
cnt += 1 cnt += 1
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment