Commit 9b3b5413 authored by Yuxin Wu

Speed up TFEventWriter initialization

parent dd138d5a
...@@ -51,7 +51,9 @@ class CallbackTimeLogger(object): ...@@ -51,7 +51,9 @@ class CallbackTimeLogger(object):
class Callbacks(Callback): class Callbacks(Callback):
""" """
A container to hold all callbacks, and trigger them iteratively. A container to hold all callbacks, and trigger them iteratively.
Note that it does nothing to before_run/after_run.
This is only used by the base trainer to run all the callbacks.
Users do not need to use this class.
""" """
def __init__(self, cbs): def __init__(self, cbs):
......
...@@ -12,6 +12,7 @@ import time ...@@ -12,6 +12,7 @@ import time
from collections import defaultdict from collections import defaultdict
from datetime import datetime from datetime import datetime
import six import six
import threading
from ..compat import tfv1 as tf from ..compat import tfv1 as tf
from ..libinfo import __git_version__ from ..libinfo import __git_version__
...@@ -23,7 +24,7 @@ from .base import Callback ...@@ -23,7 +24,7 @@ from .base import Callback
__all__ = ['MonitorBase', 'Monitors', __all__ = ['MonitorBase', 'Monitors',
'TFEventWriter', 'JSONWriter', 'TFEventWriter', 'JSONWriter',
'ScalarPrinter', 'SendMonitorData', 'ScalarPrinter', 'SendMonitorData',
'TrainingMonitor', 'CometMLMonitor'] 'CometMLMonitor']
def image_to_nhwc(arr): def image_to_nhwc(arr):
...@@ -53,7 +54,9 @@ class MonitorBase(Callback): ...@@ -53,7 +54,9 @@ class MonitorBase(Callback):
_chief_only = False _chief_only = False
def setup_graph(self, trainer): def setup_graph(self, trainer):
# Set attributes following Callback.setup_graph
self.trainer = trainer self.trainer = trainer
self.graph = tf.get_default_graph()
self._setup_graph() self._setup_graph()
def _setup_graph(self): def _setup_graph(self):
...@@ -97,12 +100,6 @@ class MonitorBase(Callback): ...@@ -97,12 +100,6 @@ class MonitorBase(Callback):
# TODO process other types # TODO process other types
TrainingMonitor = MonitorBase
"""
Old name
"""
class NoOpMonitor(MonitorBase): class NoOpMonitor(MonitorBase):
def __init__(self, name=None): def __init__(self, name=None):
self._name = name self._name = name
...@@ -259,8 +256,17 @@ class TFEventWriter(MonitorBase): ...@@ -259,8 +256,17 @@ class TFEventWriter(MonitorBase):
def _setup_graph(self): def _setup_graph(self):
self._writer = tf.summary.FileWriter( self._writer = tf.summary.FileWriter(
self._logdir, graph=tf.get_default_graph(), self._logdir, max_queue=self._max_queue, flush_secs=self._flush_secs)
max_queue=self._max_queue, flush_secs=self._flush_secs)
def _write_graph(self):
self._writer.add_graph(self.graph)
def _before_train(self):
# Writing the graph is expensive (takes ~2min) when the graph is large.
# Therefore use a separate thread. It will then run in the
# background while TF is warming up in the first several iterations.
self._write_graph_thread = threading.Thread(target=self._write_graph, daemon=True)
self._write_graph_thread.start()
@HIDE_DOC @HIDE_DOC
def process_summary(self, summary): def process_summary(self, summary):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment