Commit 7f4ca5f9 authored by Yuxin Wu's avatar Yuxin Wu

add `put_event` to monitor

parent 57619f3a
......@@ -56,8 +56,8 @@ TrainConfig(
RunUpdateOps(),
],
monitors=[ # monitors are a special kind of callbacks. these are also enabled by default
# write all monitor data to tensorboard
TFSummaryWriter(),
# write everything to tensorboard
TFEventWriter(),
# write all scalar data to a json file, for easy parsing
JSONWriter(),
# print all scalar data every epoch (can be configured differently)
......
......@@ -110,8 +110,8 @@ class Evaluator(Triggerable):
t = time.time() - t
if t > 10 * 60: # eval takes too long
self.eval_episode = int(self.eval_episode * 0.94)
self.trainer.monitors.put('mean_score', mean)
self.trainer.monitors.put('max_score', max)
self.trainer.monitors.put_scalar('mean_score', mean)
self.trainer.monitors.put_scalar('max_score', max)
def play_n_episodes(player, predfunc, nr):
......
......@@ -147,10 +147,10 @@ def get_config():
RunOp(lambda: M.reset_lstm_state()),
CallbackFactory(
trigger_epoch=lambda self:
[self.trainer.monitors.put(
[self.trainer.monitors.put_scalar(
'validation_perplexity',
np.exp(self.trainer.monitors.get_latest('validation_cost') / SEQ_LEN)),
self.trainer.monitors.put(
self.trainer.monitors.put_scalar(
'test_perplexity',
np.exp(self.trainer.monitors.get_latest('test_cost') / SEQ_LEN))]
),
......
......@@ -28,12 +28,14 @@ class Callback(object):
.. document private functions
.. automethod:: _setup_graph
.. automethod:: _before_train
.. automethod:: _after_train
.. automethod:: _before_run
.. automethod:: _after_run
.. automethod:: _before_epoch
.. automethod:: _after_epoch
.. automethod:: _trigger_step
.. automethod:: _trigger_epoch
.. automethod:: _trigger
.. automethod:: _after_train
"""
_chief_only = True
......
......@@ -18,7 +18,7 @@ from ..tfutils.summary import create_scalar_summary, create_image_summary
from .base import Callback
__all__ = ['TrainingMonitor', 'Monitors',
'TFSummaryWriter', 'JSONWriter', 'ScalarPrinter', 'SendMonitorData']
'TFSummaryWriter', 'TFEventWriter', 'JSONWriter', 'ScalarPrinter', 'SendMonitorData']
def image_to_nhwc(arr):
......@@ -75,6 +75,13 @@ class TrainingMonitor(Callback):
"""
pass
def put_event(self, evt):
"""
Args:
evt (tf.Event): the most basic format, could include Summary,
RunMetadata, LogMessage, and more.
"""
pass
# TODO put other types
......@@ -93,17 +100,9 @@ class Monitors(TrainingMonitor):
def _setup_graph(self):
self._scalar_history.setup_graph(self.trainer)
def _dispatch_put_summary(self, summary):
def _dispatch(self, func):
for m in self._monitors:
m.put_summary(summary)
def _dispatch_put_scalar(self, name, val):
for m in self._monitors:
m.put_scalar(name, val)
def _dispatch_put_image(self, name, val):
for m in self._monitors:
m.put_image(name, val)
func(m)
def put_summary(self, summary):
if isinstance(summary, six.binary_type):
......@@ -111,7 +110,7 @@ class Monitors(TrainingMonitor):
assert isinstance(summary, tf.Summary), type(summary)
# TODO remove -summary suffix for summary
self._dispatch_put_summary(summary)
self._dispatch(lambda m: m.put_summary(summary))
# TODO other types
for val in summary.value:
......@@ -120,16 +119,12 @@ class Monitors(TrainingMonitor):
suffix = '-summary' # issue#6150
if val.tag.endswith(suffix):
val.tag = val.tag[:-len(suffix)]
self._dispatch_put_scalar(val.tag, val.simple_value)
def put(self, name, val):
val = float(val) # TODO only support scalar for now
self.put_scalar(name, val)
self._dispatch(lambda m: m.put_scalar(val.tag, val.simple_value))
def put_scalar(self, name, val):
self._dispatch_put_scalar(name, val)
self._dispatch(lambda m: m.put_scalar(name, val))
s = create_scalar_summary(name, val)
self._dispatch_put_summary(s)
self._dispatch(lambda m: m.put_summary(s))
def put_image(self, name, val):
"""
......@@ -140,9 +135,18 @@ class Monitors(TrainingMonitor):
"""
assert isinstance(val, np.ndarray)
arr = image_to_nhwc(val)
self._dispatch_put_image(name, arr)
self._dispatch(lambda m: m.put_image(name, arr))
s = create_image_summary(name, arr)
self._dispatch_put_summary(s)
self._dispatch(lambda m: m.put_summary(s))
def put_event(self, evt):
"""
Simply call :meth:`put_event` on each monitor.
Args:
evt (tf.Event):
"""
self._dispatch(lambda m: m.put_event(evt))
def get_latest(self, name):
"""
......@@ -157,15 +161,15 @@ class Monitors(TrainingMonitor):
return self._scalar_history.get_history(name)
class TFSummaryWriter(TrainingMonitor):
class TFEventWriter(TrainingMonitor):
"""
Write summaries to TensorFlow event file.
"""
def __new__(cls):
if logger.LOG_DIR:
return super(TFSummaryWriter, cls).__new__(cls)
return super(TFEventWriter, cls).__new__(cls)
else:
logger.warn("logger directory was not set. Ignore TFSummaryWriter.")
logger.warn("logger directory was not set. Ignore TFEventWriter.")
return NoOpMonitor()
def _setup_graph(self):
......@@ -174,6 +178,9 @@ class TFSummaryWriter(TrainingMonitor):
def put_summary(self, summary):
self._writer.add_summary(summary, self.global_step)
def put_event(self, evt):
self._writer.add_event(evt)
def _trigger(self): # flush every epoch
self._writer.flush()
......@@ -181,6 +188,11 @@ class TFSummaryWriter(TrainingMonitor):
self._writer.close()
def TFSummaryWriter(*args, **kwargs):
logger.warn("TFSummaryWriter was renamed to TFEventWriter!")
return TFEventWriter(*args, **kwargs)
class JSONWriter(TrainingMonitor):
"""
Write all scalar data to a json, grouped by their global step.
......
......@@ -35,7 +35,7 @@ class TowerContext(object):
is_training = not self._name.startswith(PREDICT_TOWER)
self._is_training = bool(is_training)
self._index = index
self._index = int(index)
assert var_strategy in ['replicated', 'shared'], var_strategy
self._var_strategy = var_strategy
......
......@@ -5,7 +5,7 @@
from ..callbacks import (
Callbacks, MovingAverageSummary,
ProgressBar, MergeAllSummaries,
TFSummaryWriter, JSONWriter, ScalarPrinter, RunUpdateOps)
TFEventWriter, JSONWriter, ScalarPrinter, RunUpdateOps)
from ..dataflow.base import DataFlow
from ..models import ModelDesc
from ..utils import logger
......@@ -44,7 +44,7 @@ class TrainConfig(object):
``[MovingAverageSummary(), ProgressBar(), MergeAllSummaries(), RunUpdateOps()]``. The list of
callbacks that will be used in the end are ``callbacks + extra_callbacks``.
monitors (list): a list of :class:`TrainingMonitor`.
Defaults to ``[TFSummaryWriter(), JSONWriter(), ScalarPrinter()]``.
Defaults to ``[TFEventWriter(), JSONWriter(), ScalarPrinter()]``.
session_creator (tf.train.SessionCreator): Defaults to :class:`sesscreate.NewSessionCreator()`
with the config returned by :func:`tfutils.get_default_sess_config()`.
session_config (tf.ConfigProto): when session_creator is None, use this to create the session.
......@@ -92,7 +92,7 @@ class TrainConfig(object):
assert_type(self._callbacks, list)
if monitors is None:
monitors = [TFSummaryWriter(), JSONWriter(), ScalarPrinter()]
monitors = [TFEventWriter(), JSONWriter(), ScalarPrinter()]
self.monitors = monitors
self.model = model
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment