Commit 6e0b509b authored by Yuxin Wu

Catch exception in GPUUtilization.worker (#1134)

parent d04c8444
...@@ -75,6 +75,9 @@ class GPUUtilizationTracker(Callback): ...@@ -75,6 +75,9 @@ class GPUUtilizationTracker(Callback):
# Don't do this in after_epoch because # Don't do this in after_epoch because
# before,after_epoch are supposed to be extremely fast by design. # before,after_epoch are supposed to be extremely fast by design.
stats = self._queue.get() stats = self._queue.get()
if stats == -1:
from ..train.base import StopTraining
raise StopTraining("GPUUtilizationTracker.worker has failed.")
for idx, dev in enumerate(self._devices): for idx, dev in enumerate(self._devices):
self.trainer.monitors.put_scalar('GPUUtil/{}'.format(dev), stats[idx]) self.trainer.monitors.put_scalar('GPUUtil/{}'.format(dev), stats[idx])
...@@ -85,6 +88,7 @@ class GPUUtilizationTracker(Callback): ...@@ -85,6 +88,7 @@ class GPUUtilizationTracker(Callback):
def worker(self, evt, rst_queue, stop_evt): def worker(self, evt, rst_queue, stop_evt):
while True: while True:
try:
evt.wait() # start epoch evt.wait() # start epoch
evt.clear() evt.clear()
if stop_evt.is_set(): # or on exit if stop_evt.is_set(): # or on exit
...@@ -111,6 +115,10 @@ class GPUUtilizationTracker(Callback): ...@@ -111,6 +115,10 @@ class GPUUtilizationTracker(Callback):
cnt -= 1 cnt -= 1
rst_queue.put(stats / cnt) rst_queue.put(stats / cnt)
break break
except Exception:
logger.exception("Exception in GPUUtilizationTracker.worker")
rst_queue.put(-1)
return
# Can add more features from tfprof # Can add more features from tfprof
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment