Commit 54e391c0 authored by Yuxin Wu

Write profiling results to tensorboard as well. (fix #309)

parent 7f4ca5f9
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
import os import os
import numpy as np import numpy as np
import shutil import shutil
import time
import operator import operator
from collections import defaultdict from collections import defaultdict
import six import six
...@@ -142,10 +143,13 @@ class Monitors(TrainingMonitor): ...@@ -142,10 +143,13 @@ class Monitors(TrainingMonitor):
def put_event(self, evt):
    """
    Forward an event to every monitor via its :meth:`put_event`.

    The `step` and `wall_time` fields of the proto are filled in here
    (from ``self.global_step`` and ``time.time()``), overwriting whatever
    the caller may have put there.

    Args:
        evt (tf.Event): the event to dispatch; it is modified in place.
    """
    # Stamp the event once, then hand the same object to each monitor.
    evt.step = self.global_step
    evt.wall_time = time.time()

    def _forward(monitor):
        return monitor.put_event(evt)

    self._dispatch(_forward)
def get_latest(self, name): def get_latest(self, name):
......
...@@ -89,7 +89,7 @@ class GPUUtilizationTracker(Callback): ...@@ -89,7 +89,7 @@ class GPUUtilizationTracker(Callback):
break break
# TODO add more features from tfprof # Can add more features from tfprof
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/tfprof/g3doc/python_api.md # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/tfprof/g3doc/python_api.md
class GraphProfiler(Callback): class GraphProfiler(Callback):
...@@ -106,15 +106,17 @@ class GraphProfiler(Callback): ...@@ -106,15 +106,17 @@ class GraphProfiler(Callback):
You probably want to schedule it less frequently by You probably want to schedule it less frequently by
:class:`PeriodicRunHooks`. :class:`PeriodicRunHooks`.
""" """
def __init__(self, dump_metadata=False, dump_tracing=True, dump_event=False):
    """
    Configure which profiling outputs are produced.

    Args:
        dump_metadata(bool): Dump :class:`tf.RunMetadata` to be used with tfprof.
        dump_tracing(bool): Dump chrome tracing files.
        dump_event(bool): Dump to an event processed by FileWriter.

    Raises:
        AssertionError: if the logger directory is not an existing directory.
    """
    self._dir = logger.LOG_DIR
    self._dump_meta = bool(dump_metadata)
    self._dump_tracing = bool(dump_tracing)
    self._dump_event = bool(dump_event)
    # NOTE(review): LOG_DIR is presumably None when no logger directory was
    # configured -- confirm. Guarding for None here yields a readable
    # assertion message instead of an unrelated error from os.path.isdir.
    assert self._dir is not None and os.path.isdir(self._dir), \
        "GraphProfiler requires an existing log directory, but got: {}".format(self._dir)
def _before_run(self, _): def _before_run(self, _):
...@@ -128,6 +130,8 @@ class GraphProfiler(Callback): ...@@ -128,6 +130,8 @@ class GraphProfiler(Callback):
self._write_meta(meta) self._write_meta(meta)
if self._dump_tracing: if self._dump_tracing:
self._write_tracing(meta) self._write_tracing(meta)
if self._dump_event:
self._write_event(meta)
def _write_meta(self, metadata): def _write_meta(self, metadata):
fname = os.path.join( fname = os.path.join(
...@@ -142,3 +146,9 @@ class GraphProfiler(Callback): ...@@ -142,3 +146,9 @@ class GraphProfiler(Callback):
with open(fname, 'w') as f: with open(fname, 'w') as f:
f.write(tl.generate_chrome_trace_format( f.write(tl.generate_chrome_trace_format(
show_dataflow=True, show_memory=True)) show_dataflow=True, show_memory=True))
def _write_event(self, metadata):
    """
    Wrap run metadata in a ``tf.Event`` and send it to the trainer's monitors.

    Args:
        metadata: the run metadata to record. It is stored in serialized
            form (``SerializeToString()``) under the tag ``trace-<global_step>``.
    """
    trace_tag = 'trace-{}'.format(self.global_step)
    serialized = metadata.SerializeToString()

    event = tf.Event()
    event.tagged_run_metadata.tag = trace_tag
    event.tagged_run_metadata.run_metadata = serialized
    self.trainer.monitors.put_event(event)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment