Commit 0fda5f71 authored by Yuxin Wu's avatar Yuxin Wu

merge triggerable to callbacks.

parent c088e2a6
......@@ -3,11 +3,11 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
from abc import ABCMeta, abstractmethod
from abc import ABCMeta
import six
from ..tfutils.common import get_op_or_tensor_by_name
__all__ = ['Callback', 'ProxyCallback', 'CallbackFactory', 'Triggerable']
__all__ = ['Callback', 'ProxyCallback', 'CallbackFactory']
@six.add_metaclass(ABCMeta)
......@@ -31,6 +31,7 @@ class Callback(object):
.. automethod:: _after_run
.. automethod:: _trigger_step
.. automethod:: _trigger_epoch
.. automethod:: _trigger
.. automethod:: _after_train
"""
......@@ -111,7 +112,7 @@ class Callback(object):
def _trigger_step(self):
"""
Called after each :meth:`Trainer.run_step()` completes.
Called after each :meth:`Trainer.run_step()` completes. Defaults to no-op.
You can override it to implement, e.g. a ProgressBar.
"""
......@@ -122,7 +123,20 @@ class Callback(object):
def _trigger_epoch(self):
"""
Called after the completion of every epoch.
Called after the completion of every epoch. Defaults to call ``self.trigger()``
"""
self.trigger()
def trigger(self):
self._trigger()
def _trigger(self):
"""
Override this method to define a general trigger behavior, to be used with trigger schedulers.
Note that the schedulers (e.g. :class:`PeriodicTrigger`) might call this
method both inside an epoch and after an epoch.
When used without the scheduler, this method by default will be called by `trigger_epoch()`.
"""
pass
......@@ -147,40 +161,6 @@ class Callback(object):
return type(self).__name__
@six.add_metaclass(ABCMeta)
class Triggerable(Callback):
"""
Base class for "triggerable" callback. It has a method :meth:`Triggerable.trigger()`
which can be called either inside an epoch or between epochs.
Other higher-level wrappers will take the responsibility to determine **when**
to call the trigger.
If an triggerable is used as a callback directly (instead of under other
higher-level wrapper to control the trigger), it will by default trigger after
every epoch. This is mainly for backward-compatibility and convenience.
.. document private functions
.. automethod:: _trigger
.. automethod:: _trigger_epoch
"""
def trigger(self):
self._trigger()
@abstractmethod
def _trigger(self):
"""
Override this method to define what to trigger.
Note that this method may be called both inside an epoch and after an epoch.
"""
pass
def _trigger_epoch(self):
""" If a :class:`Triggerable` is used as a callback directly,
the default behavior is to run the trigger every epoch."""
self.trigger()
class ProxyCallback(Callback):
""" A callback which proxy all methods to another callback.
It's useful as a base class of callbacks which decorate other callbacks.
......
......@@ -6,14 +6,14 @@ import os
import cv2
import numpy as np
from .base import Triggerable
from .base import Callback
from ..utils import logger
from ..tfutils import get_op_tensor_name
__all__ = ['DumpParamAsImage']
class DumpParamAsImage(Triggerable):
class DumpParamAsImage(Callback):
"""
Dump a tensor to image(s) to ``logger.LOG_DIR`` after every epoch.
......
......@@ -5,12 +5,12 @@
""" Graph related callbacks"""
from .base import Triggerable
from .base import Callback
__all__ = ['RunOp']
class RunOp(Triggerable):
class RunOp(Callback):
""" Run an Op. """
def __init__(self, setup_func, run_before=True, run_epoch=True):
......
......@@ -20,7 +20,7 @@ from ..tfutils.tower import TowerContext
from ..train.input_data import TensorInput, FeedInput
from ..predict import PredictorTowerBuilder
from .base import Triggerable
from .base import Callback
from .inference import Inferencer
__all__ = ['InferenceRunner', 'FeedfreeInferenceRunner',
......@@ -54,7 +54,7 @@ def summary_inferencer(trainer, infs):
@six.add_metaclass(ABCMeta)
class InferenceRunnerBase(Triggerable):
class InferenceRunnerBase(Callback):
""" Base methods for inference runner"""
def __init__(self, input, infs, input_names=None, prefix=''):
"""
......
......@@ -9,7 +9,7 @@ import operator
import six
import os
from .base import Triggerable
from .base import Callback
from ..utils import logger
from ..tfutils import get_op_tensor_name
......@@ -107,7 +107,7 @@ class ObjAttrParam(HyperParam):
return getattr(self.obj, self.attrname)
class HyperParamSetter(Triggerable):
class HyperParamSetter(Callback):
"""
An abstract base callback to set hyperparameters.
"""
......
......@@ -6,13 +6,13 @@ import tensorflow as tf
import os
import shutil
from .base import Triggerable
from .base import Callback
from ..utils import logger
__all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
class ModelSaver(Triggerable):
class ModelSaver(Callback):
"""
Save the model every epoch.
"""
......@@ -67,7 +67,7 @@ class ModelSaver(Triggerable):
logger.exception("Exception in ModelSaver.trigger_epoch!")
class MinSaver(Triggerable):
class MinSaver(Callback):
"""
Separately save the model with minimum value of some statistics.
"""
......
......@@ -4,14 +4,14 @@
import os
from .base import Triggerable
from .base import Callback
from ..utils import logger
from ..utils.develop import log_deprecated
__all__ = ['StatPrinter', 'SendStat']
class StatPrinter(Triggerable):
class StatPrinter(Callback):
def __init__(self, print_tag=None):
log_deprecated("StatPrinter",
"No need to add StatPrinter to callbacks anymore!",
......@@ -19,7 +19,7 @@ class StatPrinter(Triggerable):
# TODO make it into monitor?
class SendStat(Triggerable):
class SendStat(Callback):
"""
Execute a command with some specific stats.
This is useful for, e.g. building a custom statistics monitor.
......
......@@ -6,7 +6,7 @@
import tensorflow as tf
from ..utils.naming import MOVING_SUMMARY_OPS_KEY
from .base import Callback, Triggerable
from .base import Callback
__all__ = ['MovingAverageSummary', 'MergeAllSummaries']
......@@ -32,7 +32,7 @@ class MovingAverageSummary(Callback):
return [self.ema_op]
class MergeAllSummaries(Triggerable):
class MergeAllSummaries(Callback):
"""
Evaluate all summaries by `tf.summary.merge_all`, and write to logs.
"""
......@@ -70,15 +70,6 @@ class MergeAllSummaries(Triggerable):
return
self.trainer.monitors.put_summary(summary)
def _summary_run_alone(self):
def _trigger(self):
summary = self.summary_op.eval()
self.trainer.monitors.put_summary(summary)
def _trigger_epoch(self):
if self._run_alone:
self._summary_run_alone()
def _trigger(self):
assert self._run_alone, \
"MergeAllSummaries can be used as a trigger only if run_alone=True."
self._summary_run_alone()
......@@ -3,7 +3,8 @@
# File: trigger.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from .base import ProxyCallback, Triggerable
from .base import ProxyCallback, Callback
from ..utils.develop import log_deprecated
__all__ = ['PeriodicTrigger', 'PeriodicCallback']
......@@ -11,12 +12,12 @@ __all__ = ['PeriodicTrigger', 'PeriodicCallback']
class PeriodicTrigger(ProxyCallback):
"""
Trigger a :class:`Triggerable` callback every k steps or every k epochs.
Schedule to trigger a callback every k steps or every k epochs by its ``_trigger()`` method.
"""
def __init__(self, triggerable, every_k_steps=None, every_k_epochs=None):
"""
Args:
triggerable (Triggerable): a Triggerable instance.
triggerable (Callback): a Callback instance with a _trigger method to be called.
every_k_steps (int): trigger when ``local_step % k == 0``. Set to
None to disable.
every_k_epochs (int): trigger when ``epoch_num % k == 0``. Set to
......@@ -24,7 +25,7 @@ class PeriodicTrigger(ProxyCallback):
every_k_steps and every_k_epochs can be both set, but cannot be both NOne.
"""
assert isinstance(triggerable, Triggerable), type(triggerable)
assert isinstance(triggerable, Callback), type(triggerable)
super(PeriodicTrigger, self).__init__(triggerable)
assert (every_k_epochs is not None) or (every_k_steps is not None), \
"every_k_steps and every_k_epochs cannot be both None!"
......@@ -54,9 +55,8 @@ class PeriodicCallback(ProxyCallback):
Wrap a callback so that after every ``period`` epochs, its :meth:`trigger_epoch`
method is called.
Note that this wrapper will proxy the :meth:`trigger_step` method as-is.
To schedule a :class:`Triggerable` callback more frequent than once per
epoch, use :class:`PeriodicTrigger` instead.
This wrapper is legacy. It will only proxy the :meth:`trigger_step` method as-is.
To be able to schedule a callback more frequent than once per epoch, use :class:`PeriodicTrigger` instead.
"""
def __init__(self, cb, period):
......@@ -67,6 +67,7 @@ class PeriodicCallback(ProxyCallback):
"""
super(PeriodicCallback, self).__init__(cb)
self.period = int(period)
log_deprecated("PeriodicCallback", "Use the more powerful `PeriodicTrigger`.")
def _trigger_epoch(self):
if self.epoch_num % self.period == 0:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment