Commit 3657bbd7 authored by Yuxin Wu's avatar Yuxin Wu

make more callbacks triggerable

parent bc0b7c63
......@@ -3,11 +3,11 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
from abc import ABCMeta
from abc import ABCMeta, abstractmethod
import six
from ..tfutils.common import get_op_or_tensor_by_name, get_global_step_value
__all__ = ['Callback', 'PeriodicCallback', 'ProxyCallback', 'CallbackFactory']
__all__ = ['Callback', 'ProxyCallback', 'CallbackFactory', 'Triggerable']
@six.add_metaclass(ABCMeta)
......@@ -128,6 +128,39 @@ class Callback(object):
return type(self).__name__
@six.add_metaclass(ABCMeta)
class Triggerable(Callback):
"""
Base class for "triggerable" callback. It has a method :meth:`Triggerable.trigger()`
which can be triggered either inside an epoch or between epochs.
The higher-level wrapper will take the responsibility to determine when
to trigger.
If an triggerable is used as a callback directly (instead of under other
higher-level wrapper to control the trigger), it will by default trigger after
every epoch. This is mainly for backward-compatibilty and convenience.
"""
def trigger(self):
"""
Trigger something.
Note that this method may be called both inside an epoch and after an epoch.
Some operations (e.g. writing scalar stats) currently will cause
problems if run inside an epoch. This will be fixed in the future.
"""
# TODO
self._trigger()
@abstractmethod
def _trigger(self):
pass
def _trigger_epoch(self):
""" If used as a callback directly, run the trigger every epoch."""
self.trigger()
class ProxyCallback(Callback):
""" A callback which proxy all methods to another callback.
It's useful as a base class of callbacks which decorate other callbacks.
......@@ -160,34 +193,6 @@ class ProxyCallback(Callback):
return "Proxy-" + str(self.cb)
class PeriodicCallback(ProxyCallback):
"""
Wrap a callback so that after every ``period`` epochs, its :meth:`trigger_epoch`
method is called.
Note that this method will proxy the :meth:`trigger_step` method as-is.
"""
def __init__(self, cb, period):
"""
Args:
cb(Callback): the callback to be triggered periodically
period(int): the period, the number of epochs for a callback to be triggered.
Note:
In ``cb``, ``self.epoch_num`` will not be the true number of
epochs any more.
"""
super(PeriodicCallback, self).__init__(cb)
self.period = int(period)
def _trigger_epoch(self):
if self.epoch_num % self.period == 0:
self.cb.trigger_epoch()
def __str__(self):
return "Periodic-" + str(self.cb)
class CallbackFactory(Callback):
"""
Create a callback with some lambdas.
......
......@@ -6,7 +6,7 @@ import os
import cv2
import numpy as np
from .trigger import Triggerable
from .base import Triggerable
from ..utils import logger
from ..tfutils import get_op_tensor_name
......
......@@ -5,7 +5,7 @@
""" Graph related callbacks"""
from .trigger import Triggerable
from .base import Triggerable
__all__ = ['RunOp']
......
......@@ -16,7 +16,7 @@ from ..tfutils import TowerContext
from ..train.input_data import FeedfreeInput
from ..predict import build_prediction_graph
from .base import Callback
from .base import Triggerable
from .inference import Inferencer
__all__ = ['InferenceRunner', 'FeedfreeInferenceRunner']
......@@ -63,7 +63,7 @@ def summary_inferencer(trainer, infs):
trainer.add_scalar_summary(k, v)
class InferenceRunner(Callback):
class InferenceRunner(Triggerable):
"""
A callback that runs a list of :class:`Inferencer` on some
:class:`DataFlow`.
......@@ -128,7 +128,7 @@ class InferenceRunner(Callback):
self.inf_to_tensors = [find_tensors(t) for t in dispatcer.get_names_for_each_entry()]
# list of list of IOTensor
def _trigger_epoch(self):
def _trigger(self):
for inf in self.infs:
inf.before_inference()
......@@ -147,7 +147,7 @@ class InferenceRunner(Callback):
summary_inferencer(self.trainer, self.infs)
class FeedfreeInferenceRunner(Callback):
class FeedfreeInferenceRunner(Triggerable):
""" A callback that runs a list of :class:`Inferencer` on some
:class:`FeedfreeInput`, such as some tensor from a TensorFlow data reading
pipeline.
......@@ -231,7 +231,7 @@ class FeedfreeInferenceRunner(Callback):
# list of list of id
self.inf_to_idxs = dispatcer.get_idx_for_each_entry()
def _trigger_epoch(self):
def _trigger(self):
for inf in self.infs:
inf.before_inference()
......
......@@ -9,7 +9,7 @@ import operator
import six
import os
from .trigger import Triggerable
from .base import Triggerable
from ..utils import logger
from ..tfutils import get_op_tensor_name
......
......@@ -6,10 +6,9 @@ import tensorflow as tf
import os
import shutil
from .base import Callback
from .base import Triggerable
from ..utils import logger
from ..tfutils.varmanip import get_savename_from_varname
from .trigger import Triggerable
__all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
......@@ -83,7 +82,7 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name))
logger.exception("Exception in ModelSaver.trigger_epoch!")
class MinSaver(Callback):
class MinSaver(Triggerable):
"""
Separately save the model with minimum value of some statistics.
"""
......@@ -126,7 +125,7 @@ class MinSaver(Callback):
return False
return v > self.min if self.reverse else v < self.min
def _trigger_epoch(self):
def _trigger(self):
if self.min is None or self._need_save():
self.min = self._get_stat()
if self.min:
......
......@@ -6,8 +6,7 @@ import os
import operator
import json
from .base import Callback
from .trigger import Triggerable
from .base import Triggerable
from ..utils import logger
__all__ = ['StatHolder', 'StatPrinter', 'SendStat']
......@@ -110,7 +109,7 @@ class StatHolder(object):
logger.exception("Exception in StatHolder.finalize()!")
class StatPrinter(Callback):
class StatPrinter(Triggerable):
"""
A callback to control what stats to print. Enable by default to print
everything in trainer.stat_holder.
......@@ -132,7 +131,7 @@ class StatPrinter(Callback):
# just try to add this stat earlier so SendStat can use
self._stat_holder.add_stat('epoch_num', self.epoch_num + 1)
def _trigger_epoch(self):
def _trigger(self):
# by default, add this two stat
self._stat_holder.add_stat('global_step', self.global_step)
self._stat_holder.finalize()
......
......@@ -3,46 +3,10 @@
# File: trigger.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from abc import abstractmethod, ABCMeta
import six
from .base import ProxyCallback, Triggerable
from .base import Callback, ProxyCallback
__all__ = ['Triggerable', 'PeriodicTrigger']
@six.add_metaclass(ABCMeta)
class Triggerable(Callback):
"""
Base class for "triggerable" callback. It has a method :meth:`Triggerable.trigger()`
which can be triggered either inside an epoch or between epochs.
The higher-level wrapper will take the responsibility to determine when
to trigger.
If an triggerable is used as a callback directly (instead of under other
higher-level wrapper to control the trigger), it will by default trigger after
every epoch. This is mainly for backward-compatibilty and convenience.
"""
def trigger(self):
"""
Trigger something.
Note that this method may be called both inside an epoch and after an epoch.
Some operations (e.g. writing scalar stats) currently will cause
problems if run inside an epoch. This will be fixed in the future.
"""
# TODO
self._trigger()
@abstractmethod
def _trigger(self):
pass
def _trigger_epoch(self):
""" If used as a callback directly, run the trigger every epoch."""
self.trigger()
__all__ = ['PeriodicTrigger', 'PeriodicCallback']
class PeriodicTrigger(ProxyCallback):
......@@ -78,3 +42,37 @@ class PeriodicTrigger(ProxyCallback):
return
if self.epoch_num % self._epoch_k == 0:
self.cb.trigger()
def __str__(self):
return "PeriodicTrigger-" + str(self.cb)
class PeriodicCallback(ProxyCallback):
"""
Wrap a callback so that after every ``period`` epochs, its :meth:`trigger_epoch`
method is called.
Note that this wrapper will proxy the :meth:`trigger_step` method as-is.
To schedule a :class:`Triggerable` callback more frequent than once per
epoch, use :class:`PeriodicTrigger` instead.
"""
def __init__(self, cb, period):
"""
Args:
cb(Callback): the callback to be triggered periodically
period(int): the period, the number of epochs for a callback to be triggered.
Note:
In ``cb``, ``self.epoch_num`` will not be the true number of
epochs any more.
"""
super(PeriodicCallback, self).__init__(cb)
self.period = int(period)
def _trigger_epoch(self):
if self.epoch_num % self.period == 0:
self.cb.trigger_epoch()
def __str__(self):
return "Periodic-" + str(self.cb)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment