Commit 54e391c0 authored by Yuxin Wu's avatar Yuxin Wu

Write profiling results to tensorboard as well. (fix #309)

parent 7f4ca5f9
......@@ -6,6 +6,7 @@
import os
import numpy as np
import shutil
import time
import operator
from collections import defaultdict
import six
......@@ -142,10 +143,13 @@ class Monitors(TrainingMonitor):
def put_event(self, evt):
"""
Simply call :meth:`put_event` on each monitor.
`step` and `wall_time` fields of this proto will be filled automatically.
Args:
evt (tf.Event):
"""
evt.step = self.global_step
evt.wall_time = time.time()
self._dispatch(lambda m: m.put_event(evt))
def get_latest(self, name):
......
......@@ -89,7 +89,7 @@ class GPUUtilizationTracker(Callback):
break
# TODO add more features from tfprof
# Can add more features from tfprof
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/tfprof/g3doc/python_api.md
class GraphProfiler(Callback):
......@@ -106,15 +106,17 @@ class GraphProfiler(Callback):
You probably want to schedule it less frequently by
:class:`PeriodicRunHooks`.
"""
def __init__(self, dump_metadata=False, dump_tracing=True):
def __init__(self, dump_metadata=False, dump_tracing=True, dump_event=False):
"""
Args:
dump_metadata(bool): Dump :class:`tf.RunMetadata` to be used with tfprof.
dump_tracing(bool): Dump chrome tracing files.
dump_event(bool): Dump to an event processed by FileWriter.
"""
self._dir = logger.LOG_DIR
self._dump_meta = bool(dump_metadata)
self._dump_tracing = bool(dump_tracing)
self._dump_event = bool(dump_event)
assert os.path.isdir(self._dir)
def _before_run(self, _):
......@@ -128,6 +130,8 @@ class GraphProfiler(Callback):
self._write_meta(meta)
if self._dump_tracing:
self._write_tracing(meta)
if self._dump_event:
self._write_event(meta)
def _write_meta(self, metadata):
fname = os.path.join(
......@@ -142,3 +146,9 @@ class GraphProfiler(Callback):
with open(fname, 'w') as f:
f.write(tl.generate_chrome_trace_format(
show_dataflow=True, show_memory=True))
def _write_event(self, metadata):
evt = tf.Event()
evt.tagged_run_metadata.tag = 'trace-{}'.format(self.global_step)
evt.tagged_run_metadata.run_metadata = metadata.SerializeToString()
self.trainer.monitors.put_event(evt)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment