Commit 9b3b5413 authored by Yuxin Wu

Speed up TFEventWriter initialization

parent dd138d5a
......@@ -51,7 +51,9 @@ class CallbackTimeLogger(object):
class Callbacks(Callback):
"""
A container to hold all callbacks, and trigger them iteratively.
Note that it does nothing to before_run/after_run.
This is only used by the base trainer to run all the callbacks.
Users do not need to use this class.
"""
def __init__(self, cbs):
......
......@@ -12,6 +12,7 @@ import time
from collections import defaultdict
from datetime import datetime
import six
import threading
from ..compat import tfv1 as tf
from ..libinfo import __git_version__
......@@ -23,7 +24,7 @@ from .base import Callback
__all__ = ['MonitorBase', 'Monitors',
'TFEventWriter', 'JSONWriter',
'ScalarPrinter', 'SendMonitorData',
'TrainingMonitor', 'CometMLMonitor']
'CometMLMonitor']
def image_to_nhwc(arr):
......@@ -53,7 +54,9 @@ class MonitorBase(Callback):
_chief_only = False
def setup_graph(self, trainer):
    """
    Bind this monitor to the given trainer and to the current default
    graph, then invoke the ``_setup_graph`` hook.

    Args:
        trainer: the trainer object; stored as ``self.trainer``.
    """
    # Attributes are set following Callback.setup_graph.
    self.trainer = trainer
    default_graph = tf.get_default_graph()
    self.graph = default_graph
    self._setup_graph()
def _setup_graph(self):
......@@ -97,12 +100,6 @@ class MonitorBase(Callback):
# TODO process other types
TrainingMonitor = MonitorBase
"""
Old name
"""
class NoOpMonitor(MonitorBase):
def __init__(self, name=None):
self._name = name
......@@ -259,8 +256,17 @@ class TFEventWriter(MonitorBase):
def _setup_graph(self):
    """
    Create the event file writer under ``self._logdir``.
    """
    # The graph is deliberately not passed to the writer here: writing a
    # large graph is expensive, so it is added separately through
    # ``_write_graph`` instead of during construction.
    self._writer = tf.summary.FileWriter(
        self._logdir, max_queue=self._max_queue, flush_secs=self._flush_secs)
def _write_graph(self):
self._writer.add_graph(self.graph)
def _before_train(self):
# Writing the graph is expensive (takes ~2min) when the graph is large.
# Therefore use a separate thread. It will then run in the
# background while TF is warming up in the first several iterations.
self._write_graph_thread = threading.Thread(target=self._write_graph, daemon=True)
self._write_graph_thread.start()
@HIDE_DOC
def process_summary(self, summary):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment