Commit 6e0b509b authored by Yuxin Wu

Catch exception in GPUUtilization.worker (#1134)

parent d04c8444
@@ -75,6 +75,9 @@ class GPUUtilizationTracker(Callback):
         # Don't do this in after_epoch because
         # before,after_epoch are supposed to be extremely fast by design.
         stats = self._queue.get()
+        if stats == -1:
+            from ..train.base import StopTraining
+            raise StopTraining("GPUUtilizationTracker.worker has failed.")
         for idx, dev in enumerate(self._devices):
             self.trainer.monitors.put_scalar('GPUUtil/{}'.format(dev), stats[idx])
@@ -85,33 +88,38 @@ class GPUUtilizationTracker(Callback):
     def worker(self, evt, rst_queue, stop_evt):
         while True:
-            evt.wait() # start epoch
-            evt.clear()
-            if stop_evt.is_set(): # or on exit
-                return
-            stats = np.zeros((len(self._devices),), dtype='f4')
-            cnt = 0
-            with NVMLContext() as ctx:
-                while True:
-                    time.sleep(1)
-                    data = [ctx.device(i).utilization()['gpu'] for i in self._devices]
-                    data = list(map(float, data))
-                    stats += data
-                    cnt += 1
-                    if evt.is_set(): # stop epoch
-                        if stop_evt.is_set(): # or on exit
-                            return
-                        evt.clear()
-                        if cnt > 1:
-                            # Ignore the last datapoint. Usually is zero, makes us underestimate the util.
-                            stats -= data
-                            cnt -= 1
-                        rst_queue.put(stats / cnt)
-                        break
+            try:
+                evt.wait() # start epoch
+                evt.clear()
+                if stop_evt.is_set(): # or on exit
+                    return
+                stats = np.zeros((len(self._devices),), dtype='f4')
+                cnt = 0
+                with NVMLContext() as ctx:
+                    while True:
+                        time.sleep(1)
+                        data = [ctx.device(i).utilization()['gpu'] for i in self._devices]
+                        data = list(map(float, data))
+                        stats += data
+                        cnt += 1
+                        if evt.is_set(): # stop epoch
+                            if stop_evt.is_set(): # or on exit
+                                return
+                            evt.clear()
+                            if cnt > 1:
+                                # Ignore the last datapoint. Usually is zero, makes us underestimate the util.
+                                stats -= data
+                                cnt -= 1
+                            rst_queue.put(stats / cnt)
+                            break
+            except Exception:
+                logger.exception("Exception in GPUUtilizationTracker.worker")
+                rst_queue.put(-1)
+                return
 # Can add more features from tfprof
 # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/profiler/README.md
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment