Commit 5467e48a authored by Yuxin Wu's avatar Yuxin Wu

Support more options on TFEventWriter (#645)

parent 54c5a42d
...@@ -196,15 +196,30 @@ class TFEventWriter(TrainingMonitor): ...@@ -196,15 +196,30 @@ class TFEventWriter(TrainingMonitor):
""" """
Write summaries to TensorFlow event file. Write summaries to TensorFlow event file.
""" """
def __init__(self, logdir=None, max_queue=10, flush_secs=120):
    """
    Store the options used later to construct the ``tf.summary.FileWriter``.

    Args:
        logdir (str): directory to write event files to.
            Defaults to ``logger.get_logger_dir()`` when None.
        max_queue (int): same as in :class:`tf.summary.FileWriter`.
        flush_secs (int): same as in :class:`tf.summary.FileWriter`.
    """
    # Resolve the default here as well as in __new__: after __new__
    # returns an instance, Python calls __init__ with the ORIGINAL
    # constructor arguments, so with a required positional `logdir`
    # a plain `TFEventWriter()` (which __new__ explicitly supports)
    # would raise TypeError here.
    if logdir is None:
        logdir = logger.get_logger_dir()
    self._logdir = logdir
    self._max_queue = max_queue
    self._flush_secs = flush_secs
def __new__(cls, logdir=None, max_queue=10, flush_secs=120):
    """
    Create the monitor only when a log directory is available.

    Returns:
        A normal :class:`TFEventWriter` instance when ``logdir`` is given
        or ``logger.get_logger_dir()`` is set; otherwise a
        :class:`NoOpMonitor`, so the writer silently disables itself.
    """
    if logdir is None:
        logdir = logger.get_logger_dir()
    if logdir is not None:
        # In Python 3, object.__new__ raises TypeError when given extra
        # arguments while __new__ is overridden — the constructor
        # options must NOT be forwarded here; __init__ receives them
        # separately from the normal instantiation protocol.
        return super(TFEventWriter, cls).__new__(cls)
    else:
        logger.warn("logger directory was not set. Ignore TFEventWriter.")
        return NoOpMonitor()
def _setup_graph(self):
    """Open the event-file writer, attaching the default graph to it."""
    # Gather the options stored by __init__ before constructing the writer.
    writer_options = dict(
        graph=tf.get_default_graph(),
        max_queue=self._max_queue,
        flush_secs=self._flush_secs,
    )
    self._writer = tf.summary.FileWriter(self._logdir, **writer_options)
def process_summary(self, summary):
    """Write one serialized summary at the current global step."""
    step = self.global_step
    self._writer.add_summary(summary, step)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment