Commit 0fda5f71 authored by Yuxin Wu's avatar Yuxin Wu

merge triggerable to callbacks.

parent c088e2a6
...@@ -3,11 +3,11 @@ ...@@ -3,11 +3,11 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf import tensorflow as tf
from abc import ABCMeta, abstractmethod from abc import ABCMeta
import six import six
from ..tfutils.common import get_op_or_tensor_by_name from ..tfutils.common import get_op_or_tensor_by_name
__all__ = ['Callback', 'ProxyCallback', 'CallbackFactory', 'Triggerable'] __all__ = ['Callback', 'ProxyCallback', 'CallbackFactory']
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
...@@ -31,6 +31,7 @@ class Callback(object): ...@@ -31,6 +31,7 @@ class Callback(object):
.. automethod:: _after_run .. automethod:: _after_run
.. automethod:: _trigger_step .. automethod:: _trigger_step
.. automethod:: _trigger_epoch .. automethod:: _trigger_epoch
.. automethod:: _trigger
.. automethod:: _after_train .. automethod:: _after_train
""" """
...@@ -111,7 +112,7 @@ class Callback(object): ...@@ -111,7 +112,7 @@ class Callback(object):
def _trigger_step(self): def _trigger_step(self):
""" """
Called after each :meth:`Trainer.run_step()` completes. Called after each :meth:`Trainer.run_step()` completes. Defaults to no-op.
You can override it to implement, e.g. a ProgressBar. You can override it to implement, e.g. a ProgressBar.
""" """
...@@ -122,7 +123,20 @@ class Callback(object): ...@@ -122,7 +123,20 @@ class Callback(object):
def _trigger_epoch(self): def _trigger_epoch(self):
""" """
Called after the completion of every epoch. Called after the completion of every epoch. Defaults to call ``self.trigger()``
"""
self.trigger()
def trigger(self):
self._trigger()
def _trigger(self):
"""
Override this method to define a general trigger behavior, to be used with trigger schedulers.
Note that the schedulers (e.g. :class:`PeriodicTrigger`) might call this
method both inside an epoch and after an epoch.
When used without the scheduler, this method by default will be called by `trigger_epoch()`.
""" """
pass pass
...@@ -147,40 +161,6 @@ class Callback(object): ...@@ -147,40 +161,6 @@ class Callback(object):
return type(self).__name__ return type(self).__name__
@six.add_metaclass(ABCMeta)
class Triggerable(Callback):
"""
Base class for "triggerable" callback. It has a method :meth:`Triggerable.trigger()`
which can be called either inside an epoch or between epochs.
Other higher-level wrappers will take the responsibility to determine **when**
to call the trigger.
If an triggerable is used as a callback directly (instead of under other
higher-level wrapper to control the trigger), it will by default trigger after
every epoch. This is mainly for backward-compatibility and convenience.
.. document private functions
.. automethod:: _trigger
.. automethod:: _trigger_epoch
"""
def trigger(self):
self._trigger()
@abstractmethod
def _trigger(self):
"""
Override this method to define what to trigger.
Note that this method may be called both inside an epoch and after an epoch.
"""
pass
def _trigger_epoch(self):
""" If a :class:`Triggerable` is used as a callback directly,
the default behavior is to run the trigger every epoch."""
self.trigger()
class ProxyCallback(Callback): class ProxyCallback(Callback):
""" A callback which proxy all methods to another callback. """ A callback which proxy all methods to another callback.
It's useful as a base class of callbacks which decorate other callbacks. It's useful as a base class of callbacks which decorate other callbacks.
......
...@@ -6,14 +6,14 @@ import os ...@@ -6,14 +6,14 @@ import os
import cv2 import cv2
import numpy as np import numpy as np
from .base import Triggerable from .base import Callback
from ..utils import logger from ..utils import logger
from ..tfutils import get_op_tensor_name from ..tfutils import get_op_tensor_name
__all__ = ['DumpParamAsImage'] __all__ = ['DumpParamAsImage']
class DumpParamAsImage(Triggerable): class DumpParamAsImage(Callback):
""" """
Dump a tensor to image(s) to ``logger.LOG_DIR`` after every epoch. Dump a tensor to image(s) to ``logger.LOG_DIR`` after every epoch.
......
...@@ -5,12 +5,12 @@ ...@@ -5,12 +5,12 @@
""" Graph related callbacks""" """ Graph related callbacks"""
from .base import Triggerable from .base import Callback
__all__ = ['RunOp'] __all__ = ['RunOp']
class RunOp(Triggerable): class RunOp(Callback):
""" Run an Op. """ """ Run an Op. """
def __init__(self, setup_func, run_before=True, run_epoch=True): def __init__(self, setup_func, run_before=True, run_epoch=True):
......
...@@ -20,7 +20,7 @@ from ..tfutils.tower import TowerContext ...@@ -20,7 +20,7 @@ from ..tfutils.tower import TowerContext
from ..train.input_data import TensorInput, FeedInput from ..train.input_data import TensorInput, FeedInput
from ..predict import PredictorTowerBuilder from ..predict import PredictorTowerBuilder
from .base import Triggerable from .base import Callback
from .inference import Inferencer from .inference import Inferencer
__all__ = ['InferenceRunner', 'FeedfreeInferenceRunner', __all__ = ['InferenceRunner', 'FeedfreeInferenceRunner',
...@@ -54,7 +54,7 @@ def summary_inferencer(trainer, infs): ...@@ -54,7 +54,7 @@ def summary_inferencer(trainer, infs):
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
class InferenceRunnerBase(Triggerable): class InferenceRunnerBase(Callback):
""" Base methods for inference runner""" """ Base methods for inference runner"""
def __init__(self, input, infs, input_names=None, prefix=''): def __init__(self, input, infs, input_names=None, prefix=''):
""" """
......
...@@ -9,7 +9,7 @@ import operator ...@@ -9,7 +9,7 @@ import operator
import six import six
import os import os
from .base import Triggerable from .base import Callback
from ..utils import logger from ..utils import logger
from ..tfutils import get_op_tensor_name from ..tfutils import get_op_tensor_name
...@@ -107,7 +107,7 @@ class ObjAttrParam(HyperParam): ...@@ -107,7 +107,7 @@ class ObjAttrParam(HyperParam):
return getattr(self.obj, self.attrname) return getattr(self.obj, self.attrname)
class HyperParamSetter(Triggerable): class HyperParamSetter(Callback):
""" """
An abstract base callback to set hyperparameters. An abstract base callback to set hyperparameters.
""" """
......
...@@ -6,13 +6,13 @@ import tensorflow as tf ...@@ -6,13 +6,13 @@ import tensorflow as tf
import os import os
import shutil import shutil
from .base import Triggerable from .base import Callback
from ..utils import logger from ..utils import logger
__all__ = ['ModelSaver', 'MinSaver', 'MaxSaver'] __all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
class ModelSaver(Triggerable): class ModelSaver(Callback):
""" """
Save the model every epoch. Save the model every epoch.
""" """
...@@ -67,7 +67,7 @@ class ModelSaver(Triggerable): ...@@ -67,7 +67,7 @@ class ModelSaver(Triggerable):
logger.exception("Exception in ModelSaver.trigger_epoch!") logger.exception("Exception in ModelSaver.trigger_epoch!")
class MinSaver(Triggerable): class MinSaver(Callback):
""" """
Separately save the model with minimum value of some statistics. Separately save the model with minimum value of some statistics.
""" """
......
...@@ -4,14 +4,14 @@ ...@@ -4,14 +4,14 @@
import os import os
from .base import Triggerable from .base import Callback
from ..utils import logger from ..utils import logger
from ..utils.develop import log_deprecated from ..utils.develop import log_deprecated
__all__ = ['StatPrinter', 'SendStat'] __all__ = ['StatPrinter', 'SendStat']
class StatPrinter(Triggerable): class StatPrinter(Callback):
def __init__(self, print_tag=None): def __init__(self, print_tag=None):
log_deprecated("StatPrinter", log_deprecated("StatPrinter",
"No need to add StatPrinter to callbacks anymore!", "No need to add StatPrinter to callbacks anymore!",
...@@ -19,7 +19,7 @@ class StatPrinter(Triggerable): ...@@ -19,7 +19,7 @@ class StatPrinter(Triggerable):
# TODO make it into monitor? # TODO make it into monitor?
class SendStat(Triggerable): class SendStat(Callback):
""" """
Execute a command with some specific stats. Execute a command with some specific stats.
This is useful for, e.g. building a custom statistics monitor. This is useful for, e.g. building a custom statistics monitor.
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
import tensorflow as tf import tensorflow as tf
from ..utils.naming import MOVING_SUMMARY_OPS_KEY from ..utils.naming import MOVING_SUMMARY_OPS_KEY
from .base import Callback, Triggerable from .base import Callback
__all__ = ['MovingAverageSummary', 'MergeAllSummaries'] __all__ = ['MovingAverageSummary', 'MergeAllSummaries']
...@@ -32,7 +32,7 @@ class MovingAverageSummary(Callback): ...@@ -32,7 +32,7 @@ class MovingAverageSummary(Callback):
return [self.ema_op] return [self.ema_op]
class MergeAllSummaries(Triggerable): class MergeAllSummaries(Callback):
""" """
Evaluate all summaries by `tf.summary.merge_all`, and write to logs. Evaluate all summaries by `tf.summary.merge_all`, and write to logs.
""" """
...@@ -70,15 +70,6 @@ class MergeAllSummaries(Triggerable): ...@@ -70,15 +70,6 @@ class MergeAllSummaries(Triggerable):
return return
self.trainer.monitors.put_summary(summary) self.trainer.monitors.put_summary(summary)
def _summary_run_alone(self): def _trigger(self):
summary = self.summary_op.eval() summary = self.summary_op.eval()
self.trainer.monitors.put_summary(summary) self.trainer.monitors.put_summary(summary)
def _trigger_epoch(self):
if self._run_alone:
self._summary_run_alone()
def _trigger(self):
assert self._run_alone, \
"MergeAllSummaries can be used as a trigger only if run_alone=True."
self._summary_run_alone()
...@@ -3,7 +3,8 @@ ...@@ -3,7 +3,8 @@
# File: trigger.py # File: trigger.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from .base import ProxyCallback, Triggerable from .base import ProxyCallback, Callback
from ..utils.develop import log_deprecated
__all__ = ['PeriodicTrigger', 'PeriodicCallback'] __all__ = ['PeriodicTrigger', 'PeriodicCallback']
...@@ -11,12 +12,12 @@ __all__ = ['PeriodicTrigger', 'PeriodicCallback'] ...@@ -11,12 +12,12 @@ __all__ = ['PeriodicTrigger', 'PeriodicCallback']
class PeriodicTrigger(ProxyCallback): class PeriodicTrigger(ProxyCallback):
""" """
Trigger a :class:`Triggerable` callback every k steps or every k epochs. Schedule to trigger a callback every k steps or every k epochs by its ``_trigger()`` method.
""" """
def __init__(self, triggerable, every_k_steps=None, every_k_epochs=None): def __init__(self, triggerable, every_k_steps=None, every_k_epochs=None):
""" """
Args: Args:
triggerable (Triggerable): a Triggerable instance. triggerable (Callback): a Callback instance with a _trigger method to be called.
every_k_steps (int): trigger when ``local_step % k == 0``. Set to every_k_steps (int): trigger when ``local_step % k == 0``. Set to
None to disable. None to disable.
every_k_epochs (int): trigger when ``epoch_num % k == 0``. Set to every_k_epochs (int): trigger when ``epoch_num % k == 0``. Set to
...@@ -24,7 +25,7 @@ class PeriodicTrigger(ProxyCallback): ...@@ -24,7 +25,7 @@ class PeriodicTrigger(ProxyCallback):
every_k_steps and every_k_epochs can be both set, but cannot be both NOne. every_k_steps and every_k_epochs can be both set, but cannot be both NOne.
""" """
assert isinstance(triggerable, Triggerable), type(triggerable) assert isinstance(triggerable, Callback), type(triggerable)
super(PeriodicTrigger, self).__init__(triggerable) super(PeriodicTrigger, self).__init__(triggerable)
assert (every_k_epochs is not None) or (every_k_steps is not None), \ assert (every_k_epochs is not None) or (every_k_steps is not None), \
"every_k_steps and every_k_epochs cannot be both None!" "every_k_steps and every_k_epochs cannot be both None!"
...@@ -54,9 +55,8 @@ class PeriodicCallback(ProxyCallback): ...@@ -54,9 +55,8 @@ class PeriodicCallback(ProxyCallback):
Wrap a callback so that after every ``period`` epochs, its :meth:`trigger_epoch` Wrap a callback so that after every ``period`` epochs, its :meth:`trigger_epoch`
method is called. method is called.
Note that this wrapper will proxy the :meth:`trigger_step` method as-is. This wrapper is legacy. It will only proxy the :meth:`trigger_step` method as-is.
To schedule a :class:`Triggerable` callback more frequent than once per To be able to schedule a callback more frequent than once per epoch, use :class:`PeriodicTrigger` instead.
epoch, use :class:`PeriodicTrigger` instead.
""" """
def __init__(self, cb, period): def __init__(self, cb, period):
...@@ -67,6 +67,7 @@ class PeriodicCallback(ProxyCallback): ...@@ -67,6 +67,7 @@ class PeriodicCallback(ProxyCallback):
""" """
super(PeriodicCallback, self).__init__(cb) super(PeriodicCallback, self).__init__(cb)
self.period = int(period) self.period = int(period)
log_deprecated("PeriodicCallback", "Use the more powerful `PeriodicTrigger`.")
def _trigger_epoch(self): def _trigger_epoch(self):
if self.epoch_num % self.period == 0: if self.epoch_num % self.period == 0:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment