Commit b4a5b6e9 authored by Yuxin Wu's avatar Yuxin Wu

proxy callback

parent 4aaf06ca
...@@ -87,8 +87,23 @@ class Callback(object): ...@@ -87,8 +87,23 @@ class Callback(object):
def _trigger_epoch(self): def _trigger_epoch(self):
pass pass
class ProxyCallback(Callback):
def __init__(self, cb):
self.cb = cb
def _before_train(self):
self.cb.before_train()
def _setup_graph(self):
self.cb.setup_graph(self.trainer)
class PeriodicCallback(Callback): def _after_train(self):
self.cb.after_train()
def _trigger_epoch(self):
self.cb.trigger_epoch()
class PeriodicCallback(ProxyCallback):
""" """
A callback to be triggered after every `period` epochs. A callback to be triggered after every `period` epochs.
Doesn't work for trigger_step Doesn't work for trigger_step
...@@ -98,15 +113,9 @@ class PeriodicCallback(Callback): ...@@ -98,15 +113,9 @@ class PeriodicCallback(Callback):
:param cb: a `Callback` :param cb: a `Callback`
:param period: int :param period: int
""" """
self.cb = cb super(PeriodicCallback, self).__init__(self, cb)
self.period = int(period) self.period = int(period)
def _before_train(self):
self.cb.before_train(self.trainer)
def _after_train(self):
self.cb.after_train()
def _trigger_epoch(self): def _trigger_epoch(self):
self.cb.epoch_num = self.epoch_num - 1 self.cb.epoch_num = self.epoch_num - 1
if self.epoch_num % self.period == 0: if self.epoch_num % self.period == 0:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment