Commit 563b9cd6 authored by Yuxin Wu's avatar Yuxin Wu

CallbackFactory. fix style. add travis email

parent 93908ddc
......@@ -16,4 +16,8 @@ script:
- cd examples && flake8 .
notifications:
email: false
email:
recipients:
- ppwwyyxxc@gmail.com
on_success: never
on_failure: change # default: always
......@@ -24,6 +24,7 @@ _CM = plt.get_cmap('jet')
14: background
"""
def colorize(img, heatmap):
""" img: bgr, [0,255]
heatmap: [0,1]
......
......@@ -6,22 +6,13 @@ import tensorflow as tf
from abc import ABCMeta
import six
__all__ = ['Callback', 'PeriodicCallback', 'ProxyCallback']
__all__ = ['Callback', 'PeriodicCallback', 'ProxyCallback', 'CallbackFactory']
@six.add_metaclass(ABCMeta)
class Callback(object):
""" Base class for all callbacks """
def before_train(self):
"""
Called right before the first iteration.
"""
self._before_train()
def _before_train(self):
pass
def setup_graph(self, trainer):
"""
Called before finalizing the graph.
......@@ -40,13 +31,13 @@ class Callback(object):
def _setup_graph(self):
pass
def after_train(self):
def before_train(self):
"""
Called after training.
Called right before the first iteration.
"""
self._after_train()
self._before_train()
def _after_train(self):
def _before_train(self):
pass
def trigger_step(self):
......@@ -68,6 +59,15 @@ class Callback(object):
def _trigger_epoch(self):
pass
def after_train(self):
"""
Called after training.
"""
self._after_train()
def _after_train(self):
pass
def __str__(self):
return type(self).__name__
......@@ -127,3 +127,35 @@ class PeriodicCallback(ProxyCallback):
def __str__(self):
return "Periodic-" + str(self.cb)
class CallbackFactory(Callback):
"""
Create a callback with some lambdas.
"""
def __init__(self, setup_graph=None, before_train=None,
trigger_epoch=None, after_train=None):
"""
Each lambda takes ``self`` as the only argument.
"""
self._cb_setup_graph = setup_graph
self._cb_before_train = before_train
self._cb_trigger_epoch = trigger_epoch
self._cb_after_train = after_train
def _setup_graph(self):
if self._cb_setup_graph:
self._cb_setup_graph(self)
def _before_train(self):
if self._cb_before_train:
self._cb_before_train(self)
def _trigger_epoch(self):
if self._cb_trigger_epoch:
self._cb_trigger_epoch(self)
def _after_train(self):
if self._cb_after_train:
self._cb_after_train(self)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment