Commit 35beb43c authored by Yuxin Wu's avatar Yuxin Wu

add comet.ml monitor

parent a4d4eafc
......@@ -380,13 +380,14 @@ _DEPRECATED_NAMES = set([
'dump_dataflow_to_process_queue',
'PrefetchOnGPUs',
# renamed stuff:
# renamed items that should not appear in docs
'DumpTensor',
'DumpParamAsImage',
'PeriodicRunHooks',
'get_nr_gpu',
'start_test', # TestDataSpeed
'ThreadedMapData',
'TrainingMonitor',
# deprecated or renamed symbolic code
'BilinearUpSample',
......
......@@ -14,14 +14,16 @@ from datetime import datetime
import six
import tensorflow as tf
from ..libinfo import __git_version__
from ..tfutils.summary import create_image_summary, create_scalar_summary
from ..utils import logger
from ..utils.develop import HIDE_DOC
from .base import Callback
__all__ = ['TrainingMonitor', 'Monitors',
__all__ = ['MonitorBase', 'Monitors',
'TFEventWriter', 'JSONWriter',
'ScalarPrinter', 'SendMonitorData']
'ScalarPrinter', 'SendMonitorData',
'TrainingMonitor', 'CometMLMonitor']
def image_to_nhwc(arr):
......@@ -39,9 +41,9 @@ def image_to_nhwc(arr):
return arr
class TrainingMonitor(Callback):
class MonitorBase(Callback):
"""
Monitor a training progress, by processing different types of
Base class for monitors which monitor a training progress, by processing different types of
summary/statistics from trainer.
.. document private functions
......@@ -95,7 +97,13 @@ class TrainingMonitor(Callback):
# TODO process other types
class NoOpMonitor(TrainingMonitor):
TrainingMonitor = MonitorBase
"""
Old name
"""
class NoOpMonitor(MonitorBase):
def __init__(self, name=None):
self._name = name
......@@ -121,7 +129,7 @@ class Monitors(Callback):
self._scalar_history = ScalarHistory()
self._monitors = monitors + [self._scalar_history]
for m in self._monitors:
assert isinstance(m, TrainingMonitor), m
assert isinstance(m, MonitorBase), m
def _setup_graph(self):
# scalar_history's other methods were not called.
......@@ -219,7 +227,7 @@ class Monitors(Callback):
return self._scalar_history.get_history(name)
class TFEventWriter(TrainingMonitor):
class TFEventWriter(MonitorBase):
"""
Write summaries to TensorFlow event file.
"""
......@@ -272,7 +280,7 @@ class TFEventWriter(TrainingMonitor):
self._writer.close()
class JSONWriter(TrainingMonitor):
class JSONWriter(MonitorBase):
"""
Write all scalar data to a json file under ``logger.get_logger_dir()``, grouped by their global step.
If found an earlier json history file, will append to it.
......@@ -390,7 +398,7 @@ class JSONWriter(TrainingMonitor):
logger.exception("Exception in JSONWriter._write_stat()!")
class ScalarPrinter(TrainingMonitor):
class ScalarPrinter(MonitorBase):
"""
Print scalar data into terminal.
"""
......@@ -460,7 +468,7 @@ class ScalarPrinter(TrainingMonitor):
self._dic = {}
class ScalarHistory(TrainingMonitor):
class ScalarHistory(MonitorBase):
"""
Only internally used by monitors.
"""
......@@ -483,7 +491,7 @@ class ScalarHistory(TrainingMonitor):
return self._dic[name]
class SendMonitorData(TrainingMonitor):
class SendMonitorData(MonitorBase):
"""
Execute a command with some specific scalar monitor data.
This is useful for, e.g. building a custom statistics monitor.
......@@ -531,3 +539,54 @@ class SendMonitorData(TrainingMonitor):
if ret != 0:
logger.error("Command '{}' failed with ret={}!".format(cmd, ret))
self.dic = {}
class CometMLMonitor(MonitorBase):
"""
Send data to https://www.comet.ml.
Note:
1. comet_ml requires you to `import comet_ml` before importing tensorflow or tensorpack.
2. The "automatic output logging" feature will make the training progress bar appear to freeze.
Therefore the feature is disabled by default.
"""
def __init__(self, experiment=None, api_key=None, tags=None, **kwargs):
"""
Args:
experiment (comet_ml.Experiment): if provided, invalidate all other arguments
api_key (str): your comet.ml API key
tags (list[str]): experiment tags
kwargs: other arguments passed to :class:`comet_ml.Experiment`.
"""
if experiment is not None:
self._exp = experiment
assert api_key is None and tags is None and len(kwargs) == 0
else:
from comet_ml import Experiment
kwargs.setdefault('log_code', True) # though it's not functioning, git patch logging requires it
kwargs.setdefault('auto_output_logging', None)
self._exp = Experiment(api_key=api_key, **kwargs)
if tags is not None:
self._exp.add_tags(tags)
self._exp.set_code("Code logging is impossible because there are too many files ...")
self._exp.log_dependency('tensorpack', __git_version__)
@property
def experiment(self):
"""
Returns: the :class:`comet_ml.Experiment` instance.
"""
return self._exp
def _before_train(self):
self._exp.set_model_graph(tf.get_default_graph())
def process_scalar(self, name, val):
self._exp.log_metric(name, val, step=self.global_step)
def _after_train(self):
self._exp.end()
def _after_epoch(self):
self._exp.log_epoch_end(self.epoch_num)
......@@ -8,7 +8,7 @@ import six
import tensorflow as tf
from six.moves import range
from ..callbacks import Callback, Callbacks, Monitors, TrainingMonitor
from ..callbacks import Callback, Callbacks, Monitors, MonitorBase
from ..callbacks.steps import MaintainStepCounter
from ..tfutils import get_global_step_value
from ..tfutils.model_utils import describe_trainable_vars
......@@ -186,7 +186,7 @@ class Trainer(object):
Args:
callbacks ([Callback]):
monitors ([TrainingMonitor]):
monitors ([MonitorBase]):
"""
assert isinstance(callbacks, list), callbacks
assert isinstance(monitors, list), monitors
......@@ -196,7 +196,7 @@ class Trainer(object):
for cb in callbacks:
self.register_callback(cb)
for cb in self._callbacks:
assert not isinstance(cb, TrainingMonitor), "Monitor cannot be pre-registered for now!"
assert not isinstance(cb, MonitorBase), "Monitor cannot be pre-registered for now!"
registered_monitors = []
for m in monitors:
if self.register_callback(m):
......
......@@ -84,7 +84,7 @@ class TrainConfig(object):
MergeAllSummaries(),
RunUpdateOps()]
monitors (list[TrainingMonitor]): Defaults to :func:`DEFAULT_MONITORS()`.
monitors (list[MonitorBase]): Defaults to :func:`DEFAULT_MONITORS()`.
session_creator (tf.train.SessionCreator): Defaults to :class:`sesscreate.NewSessionCreator()`
with the config returned by :func:`tfutils.get_default_sess_config()`.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment