Commit c51ce295 authored by Yuxin Wu

Add HostMemoryTracker

parent 31cfcadf
......@@ -31,7 +31,7 @@ MOCK_MODULES = ['tabulate', 'h5py',
'scipy', 'scipy.misc', 'scipy.io',
'tornado', 'tornado.concurrent',
'horovod', 'horovod.tensorflow',
'subprocess32', 'functools32']
'subprocess32', 'functools32', 'psutil']
# it's better to have tensorflow installed (for some docs to show)
# but it's OK to mock it as well
......@@ -385,6 +385,7 @@ _DEPRECATED_NAMES = set([
'start_test', # TestDataSpeed
'ThreadedMapData',
'TrainingMonitor',
'PeakMemoryTracker',
# deprecated or renamed symbolic code
'Deconv2D', 'psnr',
......
......@@ -485,9 +485,10 @@ if __name__ == '__main__':
ScheduledHyperParamSetter(
'learning_rate', warmup_schedule, interp='linear', step_based=True),
ScheduledHyperParamSetter('learning_rate', lr_schedule),
PeakMemoryTracker(),
GPUMemoryTracker(),
HostMemoryTracker(),
EstimatedTimeLeft(median=True),
SessionRunTimeout(60000).set_chief_only(True), # 1 minute timeout
SessionRunTimeout(60000), # 1 minute timeout
]
if cfg.TRAIN.EVAL_PERIOD > 0:
callbacks.extend([
......
......@@ -47,7 +47,7 @@ setup(
"six",
"termcolor>=1.1",
"tabulate>=0.7.7",
"tqdm>4.11.1",
"tqdm>4.29.0",
"msgpack>=0.5.2",
"msgpack-numpy>=0.4.4.2",
"pyzmq>=16",
......
......@@ -46,7 +46,7 @@ class InjectShell(Callback):
callbacks=[InjectShell('/path/to/pause-training.tmp'), ...]
# the following command will pause the training when the epoch finishes:
# the following command will pause the training and start a shell when the epoch finishes:
$ touch /path/to/pause-training.tmp
"""
......@@ -85,11 +85,11 @@ class EstimatedTimeLeft(Callback):
"""
Estimate the time left until completion of training.
"""
def __init__(self, last_k_epochs=5, median=False):
def __init__(self, last_k_epochs=5, median=True):
"""
Args:
last_k_epochs (int): Use the time spent on last k epochs to estimate total time left.
median (bool): Use mean by default. If True, use the median time spent on last k epochs.
median (bool): Use the mean or median time spent on last k epochs.
"""
self._times = deque(maxlen=last_k_epochs)
self._median = median
......
......@@ -9,6 +9,7 @@ import time
import tensorflow as tf
from six.moves import map, queue
from tensorflow.python.client import timeline
import psutil
from ..tfutils.common import gpu_available_in_session
from ..utils import logger
......@@ -17,7 +18,7 @@ from ..utils.gpu import get_num_gpu
from ..utils.nvml import NVMLContext
from .base import Callback
__all__ = ['GPUUtilizationTracker', 'GraphProfiler', 'PeakMemoryTracker']
__all__ = ['GPUUtilizationTracker', 'GraphProfiler', 'PeakMemoryTracker', 'GPUMemoryTracker', 'HostMemoryTracker']
class GPUUtilizationTracker(Callback):
......@@ -205,11 +206,11 @@ class GraphProfiler(Callback):
self.trainer.monitors.put_event(evt)
class PeakMemoryTracker(Callback):
class GPUMemoryTracker(Callback):
"""
Track peak memory used on each GPU device every epoch, by :mod:`tf.contrib.memory_stats`.
The peak memory comes from the `MaxBytesInUse` op, which might span
multiple session.run.
The peak memory comes from the ``MaxBytesInUse`` op, which is the peak memory used
in recent ``session.run`` calls.
See https://github.com/tensorflow/tensorflow/pull/13107.
"""
......@@ -245,3 +246,28 @@ class PeakMemoryTracker(Callback):
if results is not None:
for mem, dev in zip(results, self._devices):
self.trainer.monitors.put_scalar('PeakMemory(MB)' + dev, mem / 1e6)
# Backward-compatible alias: GPUMemoryTracker was previously named PeakMemoryTracker.
PeakMemoryTracker = GPUMemoryTracker
class HostMemoryTracker(Callback):
    """
    Track free RAM on the host.

    When triggered, it writes the size of free RAM into monitors,
    under the name ``HostFreeMemory (GB)``.
    The value is what ``psutil.virtual_memory().available`` reports,
    converted to GB (1 GB = 1024**3 bytes).
    It is also logged once in ``setup_graph()`` and once in ``before_train()``.
    """

    # NOTE(review): presumably this lets the callback run on non-chief workers too,
    # so each host reports its own memory -- confirm against `Callback`.
    _chief_only = False

    def _setup_graph(self):
        logger.info("[HostMemoryTracker] Free RAM in setup_graph() is {:.2f} GB.".format(self._free_ram_gb()))

    def _before_train(self):
        logger.info("[HostMemoryTracker] Free RAM in before_train() is {:.2f} GB.".format(self._free_ram_gb()))

    def _trigger(self):
        ram_gb = self._free_ram_gb()
        self.trainer.monitors.put_scalar('HostFreeMemory (GB)', ram_gb)

    def _free_ram_gb(self):
        """
        Returns:
            float: the available host memory, in GB.
        """
        # Divide by a float on purpose: `available` is an integer number of bytes, and
        # under Python 2 without true division `int / int` floors the result to whole GB.
        return psutil.virtual_memory().available / 1024.0 ** 3
......@@ -140,7 +140,7 @@ class SessionRunTimeout(Callback):
"""
Args:
timeout_in_ms (int):
"""
"""
self._timeout = int(timeout_in_ms)
opt = tf.RunOptions(timeout_in_ms=timeout_in_ms)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment