Commit b4a5b6e9 authored by Yuxin Wu's avatar Yuxin Wu

proxy callback

parent 4aaf06ca
......@@ -87,8 +87,23 @@ class Callback(object):
def _trigger_epoch(self):
pass
class ProxyCallback(Callback):
def __init__(self, cb):
self.cb = cb
def _before_train(self):
self.cb.before_train()
def _setup_graph(self):
self.cb.setup_graph(self.trainer)
class PeriodicCallback(Callback):
def _after_train(self):
self.cb.after_train()
def _trigger_epoch(self):
self.cb.trigger_epoch()
class PeriodicCallback(ProxyCallback):
"""
A callback to be triggered after every `period` epochs.
Doesn't work for trigger_step
......@@ -98,15 +113,9 @@ class PeriodicCallback(Callback):
:param cb: a `Callback`
:param period: int
"""
self.cb = cb
super(PeriodicCallback, self).__init__(self, cb)
self.period = int(period)
def _before_train(self):
self.cb.before_train(self.trainer)
def _after_train(self):
self.cb.after_train()
def _trigger_epoch(self):
self.cb.epoch_num = self.epoch_num - 1
if self.epoch_num % self.period == 0:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment