Commit 35beb43c authored by Yuxin Wu

add comet.ml monitor

parent a4d4eafc
@@ -380,13 +380,14 @@ _DEPRECATED_NAMES = set([
     'dump_dataflow_to_process_queue',
     'PrefetchOnGPUs',
 
-    # renamed stuff:
+    # renamed items that should not appear in docs
     'DumpTensor',
     'DumpParamAsImage',
     'PeriodicRunHooks',
     'get_nr_gpu',
     'start_test',   # TestDataSpeed
     'ThreadedMapData',
+    'TrainingMonitor',
 
     # deprecated or renamed symbolic code
     'BilinearUpSample',
...
@@ -14,14 +14,16 @@ from datetime import datetime
 import six
 import tensorflow as tf
 
+from ..libinfo import __git_version__
 from ..tfutils.summary import create_image_summary, create_scalar_summary
 from ..utils import logger
 from ..utils.develop import HIDE_DOC
 from .base import Callback
 
-__all__ = ['TrainingMonitor', 'Monitors',
-           'TFEventWriter', 'JSONWriter',
-           'ScalarPrinter', 'SendMonitorData']
+__all__ = ['MonitorBase', 'Monitors',
+           'TFEventWriter', 'JSONWriter',
+           'ScalarPrinter', 'SendMonitorData',
+           'TrainingMonitor', 'CometMLMonitor']
 
 
 def image_to_nhwc(arr):
@@ -39,9 +41,9 @@ def image_to_nhwc(arr):
     return arr
 
 
-class TrainingMonitor(Callback):
+class MonitorBase(Callback):
     """
-    Monitor a training progress, by processing different types of
+    Base class for monitors which monitor a training progress, by processing different types of
     summary/statistics from trainer.
 
     .. document private functions
@@ -95,7 +97,13 @@ class TrainingMonitor(Callback):
     # TODO process other types
 
 
-class NoOpMonitor(TrainingMonitor):
+TrainingMonitor = MonitorBase
+"""
+Old name
+"""
+
+
+class NoOpMonitor(MonitorBase):
     def __init__(self, name=None):
         self._name = name
@@ -121,7 +129,7 @@ class Monitors(Callback):
         self._scalar_history = ScalarHistory()
         self._monitors = monitors + [self._scalar_history]
         for m in self._monitors:
-            assert isinstance(m, TrainingMonitor), m
+            assert isinstance(m, MonitorBase), m
 
     def _setup_graph(self):
         # scalar_history's other methods were not called.
@@ -219,7 +227,7 @@ class Monitors(Callback):
         return self._scalar_history.get_history(name)
 
 
-class TFEventWriter(TrainingMonitor):
+class TFEventWriter(MonitorBase):
     """
     Write summaries to TensorFlow event file.
     """
@@ -272,7 +280,7 @@ class TFEventWriter(TrainingMonitor):
         self._writer.close()
 
 
-class JSONWriter(TrainingMonitor):
+class JSONWriter(MonitorBase):
     """
     Write all scalar data to a json file under ``logger.get_logger_dir()``, grouped by their global step.
     If found an earlier json history file, will append to it.
@@ -390,7 +398,7 @@ class JSONWriter(TrainingMonitor):
            logger.exception("Exception in JSONWriter._write_stat()!")
 
 
-class ScalarPrinter(TrainingMonitor):
+class ScalarPrinter(MonitorBase):
     """
     Print scalar data into terminal.
     """
@@ -460,7 +468,7 @@ class ScalarPrinter(TrainingMonitor):
         self._dic = {}
 
 
-class ScalarHistory(TrainingMonitor):
+class ScalarHistory(MonitorBase):
     """
     Only internally used by monitors.
     """
@@ -483,7 +491,7 @@ class ScalarHistory(TrainingMonitor):
         return self._dic[name]
 
 
-class SendMonitorData(TrainingMonitor):
+class SendMonitorData(MonitorBase):
     """
     Execute a command with some specific scalar monitor data.
     This is useful for, e.g. building a custom statistics monitor.
@@ -531,3 +539,54 @@ class SendMonitorData(TrainingMonitor):
         if ret != 0:
             logger.error("Command '{}' failed with ret={}!".format(cmd, ret))
         self.dic = {}
+
+
+class CometMLMonitor(MonitorBase):
+    """
+    Send data to https://www.comet.ml.
+
+    Note:
+        1. comet_ml requires you to `import comet_ml` before importing tensorflow or tensorpack.
+        2. The "automatic output logging" feature will make the training progress bar appear to freeze.
+           Therefore the feature is disabled by default.
+    """
+    def __init__(self, experiment=None, api_key=None, tags=None, **kwargs):
+        """
+        Args:
+            experiment (comet_ml.Experiment): if provided, invalidate all other arguments
+            api_key (str): your comet.ml API key
+            tags (list[str]): experiment tags
+            kwargs: other arguments passed to :class:`comet_ml.Experiment`.
+        """
+        if experiment is not None:
+            self._exp = experiment
+            assert api_key is None and tags is None and len(kwargs) == 0
+        else:
+            from comet_ml import Experiment
+            kwargs.setdefault('log_code', True)  # though it's not functioning, git patch logging requires it
+            kwargs.setdefault('auto_output_logging', None)
+            self._exp = Experiment(api_key=api_key, **kwargs)
+            if tags is not None:
+                self._exp.add_tags(tags)
+
+        self._exp.set_code("Code logging is impossible because there are too many files ...")
+        self._exp.log_dependency('tensorpack', __git_version__)
+
+    @property
+    def experiment(self):
+        """
+        Returns: the :class:`comet_ml.Experiment` instance.
+        """
+        return self._exp
+
+    def _before_train(self):
+        self._exp.set_model_graph(tf.get_default_graph())
+
+    def process_scalar(self, name, val):
+        self._exp.log_metric(name, val, step=self.global_step)
+
+    def _after_train(self):
+        self._exp.end()
+
+    def _after_epoch(self):
+        self._exp.log_epoch_end(self.epoch_num)
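
Below is a usage sketch (not part of this commit) of how the new `CometMLMonitor` could be attached to a training run. The entry points referenced in the comments (`TrainConfig`, `DEFAULT_MONITORS`) are assumed from the usual tensorpack API, and `YOUR_COMET_API_KEY` is a placeholder.

```python
# Usage sketch: comet_ml must be imported before tensorflow / tensorpack,
# as the class docstring above notes.
from comet_ml import Experiment

from tensorpack.callbacks import CometMLMonitor

# Option 1: let the monitor create the Experiment from an API key.
comet = CometMLMonitor(api_key="YOUR_COMET_API_KEY", tags=["baseline"])

# Option 2: hand over a pre-configured Experiment; no other arguments allowed then.
# comet = CometMLMonitor(experiment=Experiment(api_key="YOUR_COMET_API_KEY"))

# The underlying Experiment stays accessible, e.g. to log extra metadata:
comet.experiment.log_parameter("batch_size", 64)

# The monitor is then passed to the trainer like any other monitor, e.g.
#   TrainConfig(..., monitors=[comet])
# Note that passing `monitors` explicitly replaces DEFAULT_MONITORS();
# include the defaults in the list if TFEventWriter/JSONWriter/ScalarPrinter are still wanted.
```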
@@ -8,7 +8,7 @@ import six
 import tensorflow as tf
 from six.moves import range
 
-from ..callbacks import Callback, Callbacks, Monitors, TrainingMonitor
+from ..callbacks import Callback, Callbacks, Monitors, MonitorBase
 from ..callbacks.steps import MaintainStepCounter
 from ..tfutils import get_global_step_value
 from ..tfutils.model_utils import describe_trainable_vars
@@ -186,7 +186,7 @@ class Trainer(object):
         Args:
             callbacks ([Callback]):
-            monitors ([TrainingMonitor]):
+            monitors ([MonitorBase]):
         """
         assert isinstance(callbacks, list), callbacks
         assert isinstance(monitors, list), monitors
@@ -196,7 +196,7 @@ class Trainer(object):
         for cb in callbacks:
             self.register_callback(cb)
         for cb in self._callbacks:
-            assert not isinstance(cb, TrainingMonitor), "Monitor cannot be pre-registered for now!"
+            assert not isinstance(cb, MonitorBase), "Monitor cannot be pre-registered for now!"
         registered_monitors = []
         for m in monitors:
             if self.register_callback(m):
...
@@ -84,7 +84,7 @@ class TrainConfig(object):
                 MergeAllSummaries(),
                 RunUpdateOps()]
-            monitors (list[TrainingMonitor]): Defaults to :func:`DEFAULT_MONITORS()`.
+            monitors (list[MonitorBase]): Defaults to :func:`DEFAULT_MONITORS()`.
             session_creator (tf.train.SessionCreator): Defaults to :class:`sesscreate.NewSessionCreator()`
                 with the config returned by :func:`tfutils.get_default_sess_config()`.
...
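
For illustration (not part of this commit), a custom monitor written against this change would subclass the renamed `MonitorBase`; the old `TrainingMonitor` name is kept only as an alias. A minimal sketch, assuming just the `process_scalar()` hook and the `global_step` property used by the monitors above; `StdoutScalarLogger` is a hypothetical name:

```python
# A hypothetical custom monitor sketched against the renamed base class.
from tensorpack.callbacks import MonitorBase


class StdoutScalarLogger(MonitorBase):
    """Print selected scalars as the trainer produces them (illustrative only)."""

    def __init__(self, names):
        self._names = set(names)

    def process_scalar(self, name, val):
        # Called by Monitors for every scalar summary/statistic.
        if name in self._names:
            print("step {}: {}={:.4g}".format(self.global_step, name, val))
```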