Commit 0018fc13 authored by Yuxin Wu

Add PeriodicCallback (fix #651)

parent a5dc2dd8
......@@ -29,9 +29,9 @@ class RunOp(Callback):
op (tf.Operation or function): an Op, or a function that returns the Op in the graph.
The function will be called later (in the `setup_graph` callback).
run_before (bool): run the Op before training
run_as_trigger (bool): run the Op on every trigger
run_as_trigger (bool): run the Op on every :meth:`trigger()` call.
run_step (bool): run the Op every step (along with training)
verbose (bool): pring logs when the op is run.
verbose (bool): print logs when the op is run.
Examples:
The `DQN Example
......
......@@ -266,7 +266,7 @@ class JSONWriter(TrainingMonitor):
self._fname = os.path.join(self._dir, self.FILENAME)
if os.path.isfile(self._fname):
logger.info("Found existing JSON at {}, will append to it.".format(self._fname))
logger.info("Found JSON at {}, will append to it.".format(self._fname))
with open(self._fname) as f:
self._stats = json.load(f)
assert isinstance(self._stats, list), type(self._stats)
......@@ -277,7 +277,7 @@ class JSONWriter(TrainingMonitor):
pass
else:
# TODO is this a good idea?
logger.info("Found training history from JSON, now starting from epoch number {}.".format(epoch))
logger.info("Found history statistics from JSON. Rename the first epoch of this training to epoch #{}.".format(epoch))
self.trainer.loop.starting_epoch = epoch
self.trainer.loop._epoch_num = epoch - 1
else:
......@@ -289,16 +289,19 @@ class JSONWriter(TrainingMonitor):
def _trigger_step(self):
# will do this in trigger_epoch
if self.local_step != self.trainer.steps_per_epoch - 1:
self._push()
self._trigger()
def _trigger_epoch(self):
self._push()
self._trigger()
def process_scalar(self, name, val):
self._stat_now[name] = val
def _push(self):
""" Note that this method is idempotent"""
def _trigger(self):
"""
Add stats to json and dump to disk.
Note that this method is idempotent.
"""
if len(self._stat_now):
self._stat_now['epoch_num'] = self.epoch_num
self._stat_now['global_step'] = self.global_step
......@@ -354,20 +357,21 @@ class ScalarPrinter(TrainingMonitor):
if self._enable_step:
if self.local_step != self.trainer.steps_per_epoch - 1:
# not the last step
self._print_stat()
self._trigger()
else:
if not self._enable_epoch:
self._print_stat()
self._trigger()
# otherwise, will print them together
def _trigger_epoch(self):
if self._enable_epoch:
self._print_stat()
self._trigger()
def process_scalar(self, name, val):
self._dic[name] = float(val)
def _print_stat(self):
def _trigger(self):
# Print stats here
def match_regex_list(regexs, name):
for r in regexs:
if r.search(name) is not None:
......@@ -438,12 +442,9 @@ class SendMonitorData(TrainingMonitor):
self.dic[name] = val
def _trigger_step(self):
self._try_send()
def _trigger_epoch(self):
self._try_send()
self._trigger()
def _try_send(self):
def _trigger(self):
try:
v = {k: self.dic[k] for k in self.names}
except KeyError:
......
......@@ -15,7 +15,7 @@ __all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
class ModelSaver(Callback):
"""
Save the model every epoch.
Save the model once triggered.
"""
def __init__(self, max_to_keep=10,
......
......@@ -33,7 +33,7 @@ class InjectShell(Callback):
"""
Allow users to create a specific file as a signal to pause
and iteratively debug the training.
When triggered, it detects whether the file exists, and opens an
Once triggered, it detects whether the file exists, and opens an
IPython/pdb shell if yes.
In the shell, `self` is this callback, `self.trainer` is the trainer, and
from that you can access everything else.
......@@ -71,7 +71,7 @@ class InjectShell(Callback):
class DumpParamAsImage(Callback):
"""
Dump a tensor to image(s) to ``logger.get_logger_dir()`` after every epoch.
Dump a tensor to image(s) to ``logger.get_logger_dir()`` once triggered.
Note that it requires the tensor is directly evaluable, i.e. either inputs
are not its dependency (e.g. the weights of the model), or the inputs are
......
......@@ -5,15 +5,16 @@
from .base import ProxyCallback, Callback
__all__ = ['PeriodicTrigger', 'PeriodicRunHooks', 'EnableCallbackIf']
__all__ = ['PeriodicTrigger', 'PeriodicCallback', 'EnableCallbackIf']
class PeriodicTrigger(ProxyCallback):
"""
Schedule to trigger a callback every k global steps or every k epochs by its ``trigger()`` method.
Schedule to trigger a callback every k global steps or every k epochs by its :meth:`trigger()` method.
Most existing callbacks which do something every epoch are implemented
with :meth:`trigger()` method.
Note that it does not touch other methods (``before/after_run``,
``trigger_step``, etc).
All other methods (``before/after_run``, ``trigger_step``, etc) are unaffected.
"""
_chief_only = False
......@@ -21,11 +22,11 @@ class PeriodicTrigger(ProxyCallback):
def __init__(self, triggerable, every_k_steps=None, every_k_epochs=None):
"""
Args:
triggerable (Callback): a Callback instance with a _trigger method to be called.
triggerable (Callback): a Callback instance with a trigger method to be called.
every_k_steps (int): trigger when ``global_step % k == 0``. Set to
None to disable.
None to ignore.
every_k_epochs (int): trigger when ``epoch_num % k == 0``. Set to
None to disable.
None to ignore.
every_k_steps and every_k_epochs can both be set, but cannot both be None.
"""
......@@ -37,6 +38,7 @@ class PeriodicTrigger(ProxyCallback):
self._epoch_k = every_k_epochs
def _trigger_step(self):
self.cb.trigger_step()
if self._step_k is None:
return
if self.global_step % self._step_k == 0:
......@@ -54,7 +56,7 @@ class PeriodicTrigger(ProxyCallback):
class PeriodicRunHooks(ProxyCallback):
"""
Schedule the ``{before,after}_run`` methods of a callback every k global steps.
Enable the ``{before,after}_run`` methods of a callback every k global steps.
All other methods are untouched.
"""
......@@ -87,9 +89,10 @@ class PeriodicRunHooks(ProxyCallback):
class EnableCallbackIf(ProxyCallback):
"""
Enable ``{before,after}_epoch``, ``{before,after}_run``, ``trigger*``
Enable ``{before,after}_epoch``, ``{before,after}_run``,
``trigger_{epoch,step}``
methods of a callback, only when some condition is satisfied.
The other methods will be called the same.
The other methods are unaffected.
Note:
If you use ``{before,after}_run``,
......@@ -126,10 +129,6 @@ class EnableCallbackIf(ProxyCallback):
if self._pred(self):
super(EnableCallbackIf, self)._after_epoch()
def _trigger(self):
if self._pred(self):
super(EnableCallbackIf, self)._trigger()
def _trigger_epoch(self):
if self._pred(self):
super(EnableCallbackIf, self)._trigger_epoch()
......@@ -140,3 +139,48 @@ class EnableCallbackIf(ProxyCallback):
def __str__(self):
return "EnableCallbackIf-" + str(self.cb)
class PeriodicCallback(EnableCallbackIf):
    """
    Make the calls to the following methods of a callback **less** frequent:
    ``{before,after}_epoch``, ``{before,after}_run``, ``trigger_{epoch,step}``.

    These methods will be enabled only when ``global_step % every_k_steps == 0``
    or ``epoch_num % every_k_epochs == 0``. The other methods are unaffected.

    Note that this can only make a callback **less** frequent than before.
    :class:`PeriodicTrigger` can make a callback which supports the
    :meth:`trigger()` method more frequent than before.
    """

    _chief_only = False

    def __init__(self, callback, every_k_steps=None, every_k_epochs=None):
        """
        Args:
            callback (Callback): a Callback instance.
            every_k_steps (int): enable the callback when ``global_step % k == 0``. Set to
                None to ignore.
            every_k_epochs (int): enable the callback when ``epoch_num % k == 0``. Set to
                None to ignore.

        every_k_steps and every_k_epochs can both be set, but cannot both be None.
        """
        assert isinstance(callback, Callback), type(callback)
        # At least one of the two periods has to be provided.
        neither_given = every_k_steps is None and every_k_epochs is None
        assert not neither_given, \
            "every_k_steps and every_k_epochs cannot be both None!"
        self._epoch_k = every_k_epochs
        self._step_k = every_k_steps
        # The predicate is handed over as the plain class attribute (not bound
        # to this instance); EnableCallbackIf invokes it as ``self._pred(self)``,
        # which supplies the instance explicitly.
        super(PeriodicCallback, self).__init__(callback, PeriodicCallback.predicate)

    def predicate(self):
        """
        Returns:
            bool: True when the current ``global_step`` is a multiple of
            ``every_k_steps`` or the current ``epoch_num`` is a multiple of
            ``every_k_epochs``; a period left as None is ignored.
        """
        # ``or`` short-circuits, so the epoch condition is only evaluated when
        # the step condition did not already match.
        return bool(
            (self._step_k is not None and self.global_step % self._step_k == 0) or
            (self._epoch_k is not None and self.epoch_num % self._epoch_k == 0)
        )

    def __str__(self):
        return "PeriodicCallback-{}".format(str(self.cb))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment