Commit cefc7748 authored by Yuxin Wu, committed by GitHub

Add ThroughputTracker (#1201)

parent 3036e824
...@@ -86,6 +86,7 @@ if __name__ == '__main__': ...@@ -86,6 +86,7 @@ if __name__ == '__main__':
ScheduledHyperParamSetter('learning_rate', lr_schedule), ScheduledHyperParamSetter('learning_rate', lr_schedule),
GPUMemoryTracker(), GPUMemoryTracker(),
HostMemoryTracker(), HostMemoryTracker(),
# The keyword must be `samples_per_step`: that is the only argument ThroughputTracker.__init__ accepts.
ThroughputTracker(samples_per_step=cfg.TRAIN.NUM_GPUS),
EstimatedTimeLeft(median=True), EstimatedTimeLeft(median=True),
SessionRunTimeout(60000), # 1 minute timeout SessionRunTimeout(60000), # 1 minute timeout
] ]
......
...@@ -13,12 +13,14 @@ import psutil ...@@ -13,12 +13,14 @@ import psutil
from ..tfutils.common import gpu_available_in_session from ..tfutils.common import gpu_available_in_session
from ..utils import logger from ..utils import logger
from ..utils.timer import Timer
from ..utils.concurrency import ensure_proc_terminate, start_proc_mask_signal from ..utils.concurrency import ensure_proc_terminate, start_proc_mask_signal
from ..utils.gpu import get_num_gpu from ..utils.gpu import get_num_gpu
from ..utils.nvml import NVMLContext from ..utils.nvml import NVMLContext
from .base import Callback from .base import Callback
__all__ = ['GPUUtilizationTracker', 'GraphProfiler', 'PeakMemoryTracker', 'GPUMemoryTracker', 'HostMemoryTracker'] __all__ = ['GPUUtilizationTracker', 'GraphProfiler', 'PeakMemoryTracker',
'GPUMemoryTracker', 'HostMemoryTracker', 'ThroughputTracker']
class GPUUtilizationTracker(Callback): class GPUUtilizationTracker(Callback):
...@@ -276,3 +278,54 @@ class HostMemoryTracker(Callback): ...@@ -276,3 +278,54 @@ class HostMemoryTracker(Callback):
def _free_ram_gb(self): def _free_ram_gb(self):
return psutil.virtual_memory().available / 1024**3 return psutil.virtual_memory().available / 1024**3
class ThroughputTracker(Callback):
    """
    This callback writes the training throughput (in terms of either steps/sec, or samples/sec)
    to the monitors every time it is triggered.
    The throughput is computed based on the duration between the consecutive triggers.
    The time spent on callbacks after each epoch is excluded.
    """

    _chief_only = False

    def __init__(self, samples_per_step=None):
        """
        Args:
            samples_per_step (int or None): total number of samples processed in each step
                (i.e., your total batch size in each step).
                If not provided, this callback will record "steps/sec" instead of "samples/sec".
        """
        if samples_per_step is not None:
            samples_per_step = int(samples_per_step)
        self._samples_per_step = samples_per_step
        # The timer starts paused; it only runs between _before_epoch and _after_epoch.
        self._timer = Timer()
        self._timer.pause()

    # only include the time between before_epoch/after_epoch
    def _before_epoch(self):
        self._timer.resume()

    def _after_epoch(self):
        self._timer.pause()

    def _before_train(self):
        self._update_last()

    def _update_last(self):
        """
        Start a new measurement window: zero the timer and remember the current global step.
        """
        # Timer.reset() always leaves the timer running, so restore the paused
        # state if the timer was paused before the reset.
        old_pause = self._timer.is_paused()
        self._timer.reset()
        if old_pause:
            self._timer.pause()
        self._last_step = self.global_step

    def _trigger(self):
        """
        Write the throughput measured since the previous trigger (or since the
        start of training) to the monitors, then start a new measurement window.
        """
        elapsed = self._timer.seconds()
        num_steps = self.global_step - self._last_step
        self._update_last()
        if elapsed <= 0:
            # No measurable time has passed (e.g. a coarse clock, or two triggers with
            # no epoch time in between). A throughput cannot be computed, so skip
            # this report instead of raising ZeroDivisionError.
            return

        steps_per_sec = num_steps / elapsed
        if self._samples_per_step is None:
            self.trainer.monitors.put_scalar("Throughput (steps/sec)", steps_per_sec)
        else:
            self.trainer.monitors.put_scalar("Throughput (samples/sec)", steps_per_sec * self._samples_per_step)
...@@ -38,7 +38,7 @@ class ILSVRCMeta(object): ...@@ -38,7 +38,7 @@ class ILSVRCMeta(object):
dict: {cls_number: cls_name} dict: {cls_number: cls_name}
""" """
fname = os.path.join(self.dir, 'synset_words.txt') fname = os.path.join(self.dir, 'synset_words.txt')
assert os.path.isfile(fname) assert os.path.isfile(fname), fname
lines = [x.strip() for x in open(fname).readlines()] lines = [x.strip() for x in open(fname).readlines()]
return dict(enumerate(lines)) return dict(enumerate(lines))
......
...@@ -16,7 +16,7 @@ if six.PY3: ...@@ -16,7 +16,7 @@ if six.PY3:
__all__ = ['total_timer', 'timed_operation', __all__ = ['total_timer', 'timed_operation',
'print_total_timer', 'IterSpeedCounter'] 'print_total_timer', 'IterSpeedCounter', 'Timer']
@contextmanager @contextmanager
...@@ -113,3 +113,48 @@ class IterSpeedCounter(object): ...@@ -113,3 +113,48 @@ class IterSpeedCounter(object):
t = timer() - self.start t = timer() - self.start
logger.info("{}: {:.2f} sec, {} times, {:.3g} sec/time".format( logger.info("{}: {:.2f} sec, {} times, {:.3g} sec/time".format(
self.name, t, self.cnt, t / self.cnt)) self.name, t, self.cnt, t / self.cnt))
class Timer(object):
    """
    A timer class which computes the time elapsed since the start/reset of the timer,
    excluding the time during which the timer was paused.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """
        Reset the timer. The timer is running (not paused) after a reset.
        """
        self._start = timer()
        # Timestamp at which the ongoing pause began; None while the timer is running.
        self._paused_at = None
        # Accumulated duration of all pauses that have already ended.
        self._total_paused = 0

    def pause(self):
        """
        Pause the timer. The timer must be running.
        """
        assert self._paused_at is None, "Timer is already paused"
        self._paused_at = timer()

    def is_paused(self):
        """
        Returns:
            bool: whether the timer is currently paused.
        """
        return self._paused_at is not None

    def resume(self):
        """
        Resume the timer. The timer must be paused.
        """
        assert self._paused_at is not None, "Timer is not paused"
        self._total_paused += timer() - self._paused_at
        self._paused_at = None

    def seconds(self):
        """
        Returns:
            float: the total number of seconds since the start/reset of the timer, excluding the
                time in between when the timer is paused.
        """
        # While paused, the clock is frozen at the moment the pause began. Reading it
        # does not change the timer's state, and repeated reads return the same value.
        now = timer() if self._paused_at is None else self._paused_at
        return now - self._start - self._total_paused
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment