Commit ccd67e86 authored by Yuxin Wu's avatar Yuxin Wu

Add PeakMemoryTracker as a callback.

parent 3dcd7d32
...@@ -16,7 +16,7 @@ from ..utils import logger ...@@ -16,7 +16,7 @@ from ..utils import logger
from ..utils.concurrency import ensure_proc_terminate, subproc_call from ..utils.concurrency import ensure_proc_terminate, subproc_call
from ..utils.gpu import get_nr_gpu from ..utils.gpu import get_nr_gpu
__all__ = ['GPUUtilizationTracker', 'GraphProfiler'] __all__ = ['GPUUtilizationTracker', 'GraphProfiler', 'PeakMemoryTracker']
class GPUUtilizationTracker(Callback): class GPUUtilizationTracker(Callback):
...@@ -164,3 +164,33 @@ class GraphProfiler(Callback): ...@@ -164,3 +164,33 @@ class GraphProfiler(Callback):
evt.tagged_run_metadata.tag = 'trace-{}'.format(self.global_step) evt.tagged_run_metadata.tag = 'trace-{}'.format(self.global_step)
evt.tagged_run_metadata.run_metadata = metadata.SerializeToString() evt.tagged_run_metadata.run_metadata = metadata.SerializeToString()
self.trainer.monitors.put_event(evt) self.trainer.monitors.put_event(evt)
class PeakMemoryTracker(Callback):
"""
Track peak memory in each session run, by
:module:`tf.contrib.memory_stats`.
It can only be used for GPUs.
"""
def __init__(self, devices=['/gpu:0']):
"""
Args:
devices([str]): list of devices to track memory on.
"""
self._devices = devices
def _setup_graph(self):
from tensorflow.contrib.memory_stats import MaxBytesInUse
ops = []
for dev in self._devices:
with tf.device(dev):
ops.append(MaxBytesInUse())
self._fetches = tf.train.SessionRunArgs(fetches=ops)
def _before_run(self, _):
return self._fetches
def _after_run(self, _, rv):
results = rv.results
for mem, dev in zip(results, self._devices):
self.trainer.monitors.put_scalar('PeakMemory(MB)' + dev, mem / 1e6)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment