Commit 7f4ca5f9 authored by Yuxin Wu's avatar Yuxin Wu

add `put_event` to monitor

parent 57619f3a
...@@ -56,8 +56,8 @@ TrainConfig( ...@@ -56,8 +56,8 @@ TrainConfig(
RunUpdateOps(), RunUpdateOps(),
], ],
monitors=[ # monitors are a special kind of callbacks. these are also enabled by default monitors=[ # monitors are a special kind of callbacks. these are also enabled by default
# write all monitor data to tensorboard # write everything to tensorboard
TFSummaryWriter(), TFEventWriter(),
# write all scalar data to a json file, for easy parsing # write all scalar data to a json file, for easy parsing
JSONWriter(), JSONWriter(),
# print all scalar data every epoch (can be configured differently) # print all scalar data every epoch (can be configured differently)
......
...@@ -110,8 +110,8 @@ class Evaluator(Triggerable): ...@@ -110,8 +110,8 @@ class Evaluator(Triggerable):
t = time.time() - t t = time.time() - t
if t > 10 * 60: # eval takes too long if t > 10 * 60: # eval takes too long
self.eval_episode = int(self.eval_episode * 0.94) self.eval_episode = int(self.eval_episode * 0.94)
self.trainer.monitors.put('mean_score', mean) self.trainer.monitors.put_scalar('mean_score', mean)
self.trainer.monitors.put('max_score', max) self.trainer.monitors.put_scalar('max_score', max)
def play_n_episodes(player, predfunc, nr): def play_n_episodes(player, predfunc, nr):
......
...@@ -147,10 +147,10 @@ def get_config(): ...@@ -147,10 +147,10 @@ def get_config():
RunOp(lambda: M.reset_lstm_state()), RunOp(lambda: M.reset_lstm_state()),
CallbackFactory( CallbackFactory(
trigger_epoch=lambda self: trigger_epoch=lambda self:
[self.trainer.monitors.put( [self.trainer.monitors.put_scalar(
'validation_perplexity', 'validation_perplexity',
np.exp(self.trainer.monitors.get_latest('validation_cost') / SEQ_LEN)), np.exp(self.trainer.monitors.get_latest('validation_cost') / SEQ_LEN)),
self.trainer.monitors.put( self.trainer.monitors.put_scalar(
'test_perplexity', 'test_perplexity',
np.exp(self.trainer.monitors.get_latest('test_cost') / SEQ_LEN))] np.exp(self.trainer.monitors.get_latest('test_cost') / SEQ_LEN))]
), ),
......
...@@ -28,12 +28,14 @@ class Callback(object): ...@@ -28,12 +28,14 @@ class Callback(object):
.. document private functions .. document private functions
.. automethod:: _setup_graph .. automethod:: _setup_graph
.. automethod:: _before_train .. automethod:: _before_train
.. automethod:: _after_train
.. automethod:: _before_run .. automethod:: _before_run
.. automethod:: _after_run .. automethod:: _after_run
.. automethod:: _before_epoch
.. automethod:: _after_epoch
.. automethod:: _trigger_step .. automethod:: _trigger_step
.. automethod:: _trigger_epoch .. automethod:: _trigger_epoch
.. automethod:: _trigger .. automethod:: _trigger
.. automethod:: _after_train
""" """
_chief_only = True _chief_only = True
......
...@@ -18,7 +18,7 @@ from ..tfutils.summary import create_scalar_summary, create_image_summary ...@@ -18,7 +18,7 @@ from ..tfutils.summary import create_scalar_summary, create_image_summary
from .base import Callback from .base import Callback
__all__ = ['TrainingMonitor', 'Monitors', __all__ = ['TrainingMonitor', 'Monitors',
'TFSummaryWriter', 'JSONWriter', 'ScalarPrinter', 'SendMonitorData'] 'TFSummaryWriter', 'TFEventWriter', 'JSONWriter', 'ScalarPrinter', 'SendMonitorData']
def image_to_nhwc(arr): def image_to_nhwc(arr):
...@@ -75,6 +75,13 @@ class TrainingMonitor(Callback): ...@@ -75,6 +75,13 @@ class TrainingMonitor(Callback):
""" """
pass pass
def put_event(self, evt):
"""
Args:
evt (tf.Event): the most basic format, could include Summary,
RunMetadata, LogMessage, and more.
"""
pass
# TODO put other types # TODO put other types
...@@ -93,17 +100,9 @@ class Monitors(TrainingMonitor): ...@@ -93,17 +100,9 @@ class Monitors(TrainingMonitor):
def _setup_graph(self): def _setup_graph(self):
self._scalar_history.setup_graph(self.trainer) self._scalar_history.setup_graph(self.trainer)
def _dispatch_put_summary(self, summary): def _dispatch(self, func):
for m in self._monitors: for m in self._monitors:
m.put_summary(summary) func(m)
def _dispatch_put_scalar(self, name, val):
for m in self._monitors:
m.put_scalar(name, val)
def _dispatch_put_image(self, name, val):
for m in self._monitors:
m.put_image(name, val)
def put_summary(self, summary): def put_summary(self, summary):
if isinstance(summary, six.binary_type): if isinstance(summary, six.binary_type):
...@@ -111,7 +110,7 @@ class Monitors(TrainingMonitor): ...@@ -111,7 +110,7 @@ class Monitors(TrainingMonitor):
assert isinstance(summary, tf.Summary), type(summary) assert isinstance(summary, tf.Summary), type(summary)
# TODO remove -summary suffix for summary # TODO remove -summary suffix for summary
self._dispatch_put_summary(summary) self._dispatch(lambda m: m.put_summary(summary))
# TODO other types # TODO other types
for val in summary.value: for val in summary.value:
...@@ -120,16 +119,12 @@ class Monitors(TrainingMonitor): ...@@ -120,16 +119,12 @@ class Monitors(TrainingMonitor):
suffix = '-summary' # issue#6150 suffix = '-summary' # issue#6150
if val.tag.endswith(suffix): if val.tag.endswith(suffix):
val.tag = val.tag[:-len(suffix)] val.tag = val.tag[:-len(suffix)]
self._dispatch_put_scalar(val.tag, val.simple_value) self._dispatch(lambda m: m.put_scalar(val.tag, val.simple_value))
def put(self, name, val):
val = float(val) # TODO only support scalar for now
self.put_scalar(name, val)
def put_scalar(self, name, val): def put_scalar(self, name, val):
self._dispatch_put_scalar(name, val) self._dispatch(lambda m: m.put_scalar(name, val))
s = create_scalar_summary(name, val) s = create_scalar_summary(name, val)
self._dispatch_put_summary(s) self._dispatch(lambda m: m.put_summary(s))
def put_image(self, name, val): def put_image(self, name, val):
""" """
...@@ -140,9 +135,18 @@ class Monitors(TrainingMonitor): ...@@ -140,9 +135,18 @@ class Monitors(TrainingMonitor):
""" """
assert isinstance(val, np.ndarray) assert isinstance(val, np.ndarray)
arr = image_to_nhwc(val) arr = image_to_nhwc(val)
self._dispatch_put_image(name, arr) self._dispatch(lambda m: m.put_image(name, arr))
s = create_image_summary(name, arr) s = create_image_summary(name, arr)
self._dispatch_put_summary(s) self._dispatch(lambda m: m.put_summary(s))
def put_event(self, evt):
"""
Simply call :meth:`put_event` on each monitor.
Args:
evt (tf.Event):
"""
self._dispatch(lambda m: m.put_event(evt))
def get_latest(self, name): def get_latest(self, name):
""" """
...@@ -157,15 +161,15 @@ class Monitors(TrainingMonitor): ...@@ -157,15 +161,15 @@ class Monitors(TrainingMonitor):
return self._scalar_history.get_history(name) return self._scalar_history.get_history(name)
class TFSummaryWriter(TrainingMonitor): class TFEventWriter(TrainingMonitor):
""" """
Write summaries to TensorFlow event file. Write summaries to TensorFlow event file.
""" """
def __new__(cls): def __new__(cls):
if logger.LOG_DIR: if logger.LOG_DIR:
return super(TFSummaryWriter, cls).__new__(cls) return super(TFEventWriter, cls).__new__(cls)
else: else:
logger.warn("logger directory was not set. Ignore TFSummaryWriter.") logger.warn("logger directory was not set. Ignore TFEventWriter.")
return NoOpMonitor() return NoOpMonitor()
def _setup_graph(self): def _setup_graph(self):
...@@ -174,6 +178,9 @@ class TFSummaryWriter(TrainingMonitor): ...@@ -174,6 +178,9 @@ class TFSummaryWriter(TrainingMonitor):
def put_summary(self, summary): def put_summary(self, summary):
self._writer.add_summary(summary, self.global_step) self._writer.add_summary(summary, self.global_step)
def put_event(self, evt):
self._writer.add_event(evt)
def _trigger(self): # flush every epoch def _trigger(self): # flush every epoch
self._writer.flush() self._writer.flush()
...@@ -181,6 +188,11 @@ class TFSummaryWriter(TrainingMonitor): ...@@ -181,6 +188,11 @@ class TFSummaryWriter(TrainingMonitor):
self._writer.close() self._writer.close()
def TFSummaryWriter(*args, **kwargs):
logger.warn("TFSummaryWriter was renamed to TFEventWriter!")
return TFEventWriter(*args, **kwargs)
class JSONWriter(TrainingMonitor): class JSONWriter(TrainingMonitor):
""" """
Write all scalar data to a json, grouped by their global step. Write all scalar data to a json, grouped by their global step.
......
...@@ -35,7 +35,7 @@ class TowerContext(object): ...@@ -35,7 +35,7 @@ class TowerContext(object):
is_training = not self._name.startswith(PREDICT_TOWER) is_training = not self._name.startswith(PREDICT_TOWER)
self._is_training = bool(is_training) self._is_training = bool(is_training)
self._index = index self._index = int(index)
assert var_strategy in ['replicated', 'shared'], var_strategy assert var_strategy in ['replicated', 'shared'], var_strategy
self._var_strategy = var_strategy self._var_strategy = var_strategy
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
from ..callbacks import ( from ..callbacks import (
Callbacks, MovingAverageSummary, Callbacks, MovingAverageSummary,
ProgressBar, MergeAllSummaries, ProgressBar, MergeAllSummaries,
TFSummaryWriter, JSONWriter, ScalarPrinter, RunUpdateOps) TFEventWriter, JSONWriter, ScalarPrinter, RunUpdateOps)
from ..dataflow.base import DataFlow from ..dataflow.base import DataFlow
from ..models import ModelDesc from ..models import ModelDesc
from ..utils import logger from ..utils import logger
...@@ -44,7 +44,7 @@ class TrainConfig(object): ...@@ -44,7 +44,7 @@ class TrainConfig(object):
``[MovingAverageSummary(), ProgressBar(), MergeAllSummaries(), RunUpdateOps()]``. The list of ``[MovingAverageSummary(), ProgressBar(), MergeAllSummaries(), RunUpdateOps()]``. The list of
callbacks that will be used in the end are ``callbacks + extra_callbacks``. callbacks that will be used in the end are ``callbacks + extra_callbacks``.
monitors (list): a list of :class:`TrainingMonitor`. monitors (list): a list of :class:`TrainingMonitor`.
Defaults to ``[TFSummaryWriter(), JSONWriter(), ScalarPrinter()]``. Defaults to ``[TFEventWriter(), JSONWriter(), ScalarPrinter()]``.
session_creator (tf.train.SessionCreator): Defaults to :class:`sesscreate.NewSessionCreator()` session_creator (tf.train.SessionCreator): Defaults to :class:`sesscreate.NewSessionCreator()`
with the config returned by :func:`tfutils.get_default_sess_config()`. with the config returned by :func:`tfutils.get_default_sess_config()`.
session_config (tf.ConfigProto): when session_creator is None, use this to create the session. session_config (tf.ConfigProto): when session_creator is None, use this to create the session.
...@@ -92,7 +92,7 @@ class TrainConfig(object): ...@@ -92,7 +92,7 @@ class TrainConfig(object):
assert_type(self._callbacks, list) assert_type(self._callbacks, list)
if monitors is None: if monitors is None:
monitors = [TFSummaryWriter(), JSONWriter(), ScalarPrinter()] monitors = [TFEventWriter(), JSONWriter(), ScalarPrinter()]
self.monitors = monitors self.monitors = monitors
self.model = model self.model = model
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment