Commit b0fc944d authored by Yuxin Wu's avatar Yuxin Wu

add Triggerable callback

parent 2df3dcf4
......@@ -9,7 +9,7 @@ Here are a list of things that were changed, starting from an early version.
TensorFlow itself also changes API and those are not listed here.
* 2017/01/25. Argument order of `models.ConcatWith` is changed to follow the API change in
TensorFlow upstream.
TensorFlow upstream. See [commit](https://github.com/ppwwyyxx/tensorpack/commit/2df3dcf401a99fe61c699ad719e95528872d3abe).
* 2017/01/25. `TrainConfig(callbacks=)` now takes a list of `Callback` instances. See [commit](https://github.com/ppwwyyxx/tensorpack/commit/243e957fe6d62a0cfb5728bd77fb3e005d6603e4)
on how to change your code.
* 2017/01/06. `summary.add_moving_summary` now takes any number of positional arguments instead of a list.
......
......@@ -150,14 +150,16 @@ class ProxyCallback(Callback):
class PeriodicCallback(ProxyCallback):
"""
Wrap a callback so that it is triggered after every ``period`` epochs.
Wrap a callback so that after every ``period`` epochs, its :meth:`trigger_epoch`
method is called.
Note that this method will proxy the :meth:`trigger_step` method as-is.
"""
def __init__(self, cb, period):
"""
Args:
cb(Callback): the callback to be triggered periodically
period(int): the period
period(int): the period, the number of epochs for a callback to be triggered.
Note:
In ``cb``, ``self.epoch_num`` will not be the true number of
......
......@@ -6,29 +6,34 @@ import os
import cv2
import numpy as np
from .base import Callback
from .trigger import Triggerable
from ..utils import logger
from ..tfutils import get_op_tensor_name
__all__ = ['DumpParamAsImage']
class DumpParamAsImage(Callback):
class DumpParamAsImage(Triggerable):
"""
Dump a variable to image(s) to ``logger.LOG_DIR`` after every epoch.
Dump a tensor to image(s) to ``logger.LOG_DIR`` after every epoch.
Note that it requires the tensor is directly evaluable, i.e. either inputs
are not its dependency (e.g. the weights of the model), or the inputs are
feedfree (in which case this callback will take an extra datapoint from
the input pipeline).
"""
def __init__(self, var_name, prefix=None, map_func=None, scale=255, clip=False):
def __init__(self, tensor_name, prefix=None, map_func=None, scale=255, clip=False):
"""
Args:
var_name (str): the name of the variable.
tensor_name (str): the name of the tensor.
prefix (str): the filename prefix for saved images. Defaults to the Op name.
map_func: map the value of the variable to an image or list of
map_func: map the value of the tensor to an image or list of
images of shape [h, w] or [h, w, c]. If None, will use identity.
scale (float): a multiplier on pixel values, applied after map_func.
clip (bool): whether to clip the result to [0, 255].
"""
op_name, self.var_name = get_op_tensor_name(var_name)
op_name, self.tensor_name = get_op_tensor_name(tensor_name)
self.func = map_func
if prefix is None:
self.prefix = op_name
......@@ -40,10 +45,10 @@ class DumpParamAsImage(Callback):
def _before_train(self):
# TODO might not work for multiGPU?
self.var = self.graph.get_tensor_by_name(self.var_name)
self._tensor = self.graph.get_tensor_by_name(self.tensor_name)
def _trigger_epoch(self):
val = self.trainer.sess.run(self.var)
def _trigger(self):
val = self.trainer.sess.run(self._tensor)
if self.func is not None:
val = self.func(val)
if isinstance(val, list) or val.ndim == 4:
......
......@@ -5,12 +5,12 @@
""" Graph related callbacks"""
from .base import Callback
from .trigger import Triggerable
__all__ = ['RunOp']
class RunOp(Callback):
class RunOp(Triggerable):
""" Run an Op. """
def __init__(self, setup_func, run_before=True, run_epoch=True):
......@@ -36,6 +36,6 @@ class RunOp(Callback):
if self.run_before:
self._op.run()
def _trigger_epoch(self):
def _trigger(self):
if self.run_epoch:
self._op.run()
......@@ -9,7 +9,7 @@ import operator
import six
import os
from .base import Callback
from .trigger import Triggerable
from ..utils import logger
from ..tfutils import get_op_tensor_name
......@@ -111,7 +111,7 @@ class ObjAttrParam(HyperParam):
return getattr(self.obj, self.attrname)
class HyperParamSetter(Callback):
class HyperParamSetter(Triggerable):
"""
An abstract base callback to set hyperparameters in every epoch.
"""
......@@ -160,7 +160,7 @@ class HyperParamSetter(Callback):
"""
return self.param.get_value()
def _trigger_epoch(self):
def _trigger(self):
self._set_param()
def _before_train(self):
......
......@@ -9,11 +9,12 @@ import shutil
from .base import Callback
from ..utils import logger
from ..tfutils.varmanip import get_savename_from_varname
from .trigger import Triggerable
__all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
class ModelSaver(Callback):
class ModelSaver(Triggerable):
"""
Save the model every epoch.
"""
......@@ -64,7 +65,7 @@ class ModelSaver(Callback):
due to an alternative in a different tower".format(v.name, var_dict[name].name))
return var_dict
def _trigger_epoch(self):
def _trigger(self):
try:
if not self.meta_graph_written:
self.saver.export_meta_graph(
......
......@@ -7,6 +7,7 @@ import operator
import json
from .base import Callback
from .trigger import Triggerable
from ..utils import logger
from ..tfutils.common import get_global_step_value
......@@ -139,7 +140,7 @@ class StatPrinter(Callback):
self._stat_holder.add_stat('epoch_num', self.epoch_num + 1)
class SendStat(Callback):
class SendStat(Triggerable):
"""
Execute a command with some specific stats.
This is useful for, e.g. building a custom statistics monitor.
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: trigger.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from abc import abstractmethod, ABCMeta
import six
from .base import Callback
__all__ = ['Triggerable']
@six.add_metaclass(ABCMeta)
class Triggerable(Callback):
"""
Base class for "triggerable" callback. It has a method :meth:`Triggerable.trigger()`
which can be triggered either inside an epoch or between epochs.
The higher-level wrapper will take the responsibility to determine when
to trigger.
If an triggerable is used as a callback directly (instead of under other
higher-level wrapper to control the trigger), it will by default trigger after
every epoch. This is mainly for backward-compatibilty and convenience.
"""
def trigger(self):
"""
Trigger something.
Note that this method may be called both inside an epoch and after an epoch.
Some operations (e.g. writing scalar stats) currently will cause
problems if run inside an epoch. This will be fixed in the future.
"""
# TODO
self._trigger()
@abstractmethod
def _trigger(self):
pass
def _trigger_epoch(self):
""" If used as a callback directly, run the trigger every epoch."""
self.trigger()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment