Commit 2cf9ad75 authored by Yuxin Wu's avatar Yuxin Wu

before_train option in PeriodicTrigger

parent ba965954
...@@ -15,10 +15,11 @@ class PeriodicTrigger(ProxyCallback): ...@@ -15,10 +15,11 @@ class PeriodicTrigger(ProxyCallback):
with :meth:`trigger()` method. By default the :meth:`trigger()` method will be called every epoch. with :meth:`trigger()` method. By default the :meth:`trigger()` method will be called every epoch.
This wrapper can make the callback run at a different frequency. This wrapper can make the callback run at a different frequency.
All other methods (``before/after_run``, ``trigger_step``, etc) of the given callback are unaffected. All other methods (``before/after_run``, ``trigger_step``, etc) of the given callback
are unaffected. They will still be called as-is.
""" """
def __init__(self, triggerable, every_k_steps=None, every_k_epochs=None): def __init__(self, triggerable, every_k_steps=None, every_k_epochs=None, before_train=False):
""" """
Args: Args:
triggerable (Callback): a Callback instance with a trigger method to be called. triggerable (Callback): a Callback instance with a trigger method to be called.
...@@ -26,15 +27,23 @@ class PeriodicTrigger(ProxyCallback): ...@@ -26,15 +27,23 @@ class PeriodicTrigger(ProxyCallback):
None to ignore. None to ignore.
every_k_epochs (int): trigger when ``epoch_num % k == 0``. Set to every_k_epochs (int): trigger when ``epoch_num % k == 0``. Set to
None to ignore. None to ignore.
before_train (bool): trigger in the :meth:`before_train` method.
every_k_steps and every_k_epochs can be both set, but cannot be both None. every_k_steps and every_k_epochs can be both set, but cannot be both None unless before_train is True.
""" """
assert isinstance(triggerable, Callback), type(triggerable) assert isinstance(triggerable, Callback), type(triggerable)
super(PeriodicTrigger, self).__init__(triggerable) super(PeriodicTrigger, self).__init__(triggerable)
if before_train is False:
assert (every_k_epochs is not None) or (every_k_steps is not None), \ assert (every_k_epochs is not None) or (every_k_steps is not None), \
"every_k_steps and every_k_epochs cannot be both None!" "Arguments to PeriodicTrigger have disabled the triggerable!"
self._step_k = every_k_steps self._step_k = every_k_steps
self._epoch_k = every_k_epochs self._epoch_k = every_k_epochs
self._before_train = before_train
def _before_train(self):
self.cb.before_train()
if self._before_train:
self.cb.trigger()
def _trigger_step(self): def _trigger_step(self):
self.cb.trigger_step() self.cb.trigger_step()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment