Commit 563b9cd6 authored by Yuxin Wu's avatar Yuxin Wu

CallbackFactory. fix style. add travis email

parent 93908ddc
...@@ -16,4 +16,8 @@ script: ...@@ -16,4 +16,8 @@ script:
- cd examples && flake8 . - cd examples && flake8 .
notifications: notifications:
email: false email:
recipients:
- ppwwyyxxc@gmail.com
on_success: never
on_failure: change # default: always
...@@ -24,6 +24,7 @@ _CM = plt.get_cmap('jet') ...@@ -24,6 +24,7 @@ _CM = plt.get_cmap('jet')
14: background 14: background
""" """
def colorize(img, heatmap): def colorize(img, heatmap):
""" img: bgr, [0,255] """ img: bgr, [0,255]
heatmap: [0,1] heatmap: [0,1]
......
...@@ -6,22 +6,13 @@ import tensorflow as tf ...@@ -6,22 +6,13 @@ import tensorflow as tf
from abc import ABCMeta from abc import ABCMeta
import six import six
__all__ = ['Callback', 'PeriodicCallback', 'ProxyCallback'] __all__ = ['Callback', 'PeriodicCallback', 'ProxyCallback', 'CallbackFactory']
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
class Callback(object): class Callback(object):
""" Base class for all callbacks """ """ Base class for all callbacks """
def before_train(self):
"""
Called right before the first iteration.
"""
self._before_train()
def _before_train(self):
pass
def setup_graph(self, trainer): def setup_graph(self, trainer):
""" """
Called before finalizing the graph. Called before finalizing the graph.
...@@ -40,13 +31,13 @@ class Callback(object): ...@@ -40,13 +31,13 @@ class Callback(object):
def _setup_graph(self): def _setup_graph(self):
pass pass
def after_train(self): def before_train(self):
""" """
Called after training. Called right before the first iteration.
""" """
self._after_train() self._before_train()
def _after_train(self): def _before_train(self):
pass pass
def trigger_step(self): def trigger_step(self):
...@@ -68,6 +59,15 @@ class Callback(object): ...@@ -68,6 +59,15 @@ class Callback(object):
def _trigger_epoch(self): def _trigger_epoch(self):
pass pass
def after_train(self):
"""
Called after training.
"""
self._after_train()
def _after_train(self):
pass
def __str__(self): def __str__(self):
return type(self).__name__ return type(self).__name__
...@@ -127,3 +127,35 @@ class PeriodicCallback(ProxyCallback): ...@@ -127,3 +127,35 @@ class PeriodicCallback(ProxyCallback):
def __str__(self): def __str__(self):
return "Periodic-" + str(self.cb) return "Periodic-" + str(self.cb)
class CallbackFactory(Callback):
"""
Create a callback with some lambdas.
"""
def __init__(self, setup_graph=None, before_train=None,
trigger_epoch=None, after_train=None):
"""
Each lambda takes ``self`` as the only argument.
"""
self._cb_setup_graph = setup_graph
self._cb_before_train = before_train
self._cb_trigger_epoch = trigger_epoch
self._cb_after_train = after_train
def _setup_graph(self):
if self._cb_setup_graph:
self._cb_setup_graph(self)
def _before_train(self):
if self._cb_before_train:
self._cb_before_train(self)
def _trigger_epoch(self):
if self._cb_trigger_epoch:
self._cb_trigger_epoch(self)
def _after_train(self):
if self._cb_after_train:
self._cb_after_train(self)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment