Commit 4c926eb7 authored by Yuxin Wu's avatar Yuxin Wu

docs of callbacks (#147)

parent f3644ce9
......@@ -8,6 +8,9 @@ so you won't need to look at here very often.
Here are a list of things that were changed, starting from an early version.
TensorFlow itself also changes API and those are not listed here.
+ 2017/02/20. The interface of step callbacks are changed to be the same as `tf.train.SessionRunHook`.
If you haven't written any custom step callbacks, there is nothing to do. Otherwise please refer
to the [existing callbacks](https://github.com/ppwwyyxx/tensorpack/blob/master/tensorpack/callbacks/steps.py).
+ 2017/02/12. `TrainConfig(optimizer=)` was deprecated. Now optimizer is set in `ModelDesc`. And
gradient processors become part of an optimizer. See [commit](https://github.com/ppwwyyxx/tensorpack/commit/d1041a77a9c59d8c9abf64f389f3b605d65b483e).
* 2017/02/11. `_get_input_vars()` in `ModelDesc` was renamed to `_get_inputs`. `InputVar` was
......
......@@ -3,5 +3,5 @@ tensorpack.callbacks package
.. automodule:: tensorpack.callbacks
:members:
:undoc-members:
:no-undoc-members:
:show-inheritance:
......@@ -15,9 +15,7 @@ class Callback(object):
""" Base class for all callbacks
Attributes:
epoch_num(int): the current epoch num, starting from 1.
local_step(int): the current local step number (1-based) in the current epoch.
which is also the number of steps that have finished.
epoch_num(int): the number of the current epoch.
global_step(int): the number of global steps that have finished.
trainer(Trainer): the trainer.
graph(tf.Graph): the graph.
......@@ -25,16 +23,18 @@ class Callback(object):
Note:
These attributes are available only after (and including)
:meth:`_setup_graph`.
"""
def setup_graph(self, trainer):
.. document private functions
.. automethod:: _setup_graph
.. automethod:: _before_train
.. automethod:: _before_run
.. automethod:: _after_run
.. automethod:: _trigger_step
.. automethod:: _trigger_epoch
.. automethod:: _after_train
"""
Called before finalizing the graph.
Use this callback to setup some ops used in the callback.
Args:
trainer(Trainer): the trainer which calls the callback
"""
def setup_graph(self, trainer):
self._steps_per_epoch = trainer.config.steps_per_epoch
self.trainer = trainer
self.graph = tf.get_default_graph()
......@@ -42,37 +42,30 @@ class Callback(object):
self._setup_graph()
def _setup_graph(self):
"""
Called before finalizing the graph.
Override this method to setup the ops used in the callback.
This is the same as ``tf.train.SessionRunHook.begin()``.
"""
pass
def before_train(self):
"""
Called right before the first iteration.
"""
self._starting_step = get_global_step_value()
self._before_train()
def _before_train(self):
pass
def trigger_step(self):
"""
Callback to be triggered after every run_step.
"""
self._trigger_step()
Called right before the first iteration. The main difference to
`setup_graph` is that at this point the graph is finalized and a
default session is initialized.
Override this method to, e.g. run some operations under the session.
def _trigger_step(self):
pass
def after_run(self, run_context, run_values):
self._after_run(run_context, run_values)
def _after_run(self, run_context, run_values):
This is similar to ``tf.train.SessionRunHook.after_create_session()``, but different:
it is called after the session is initialized by :class:`tfutils.SessionInit`.
"""
pass
def before_run(self, ctx):
"""
Same as ``tf.train.SessionRunHook.before_run``.
"""
fetches = self._before_run(ctx)
if fetches is None:
return None
......@@ -91,24 +84,56 @@ class Callback(object):
return tf.train.SessionRunArgs(fetches=ret)
def _before_run(self, ctx):
"""
It is called before every ``hooked_sess.run()`` call, and it
registers some extra op/tensors to run in the next call.
This method is the same as ``tf.train.SessionRunHook.before_run``.
Refer to TensorFlow docs for more details.
An extra feature is that you can also simply return a list of names,
instead of a ``tf.train.SessionRunArgs``.
"""
return None
def trigger_epoch(self):
def after_run(self, run_context, run_values):
self._after_run(run_context, run_values)
def _after_run(self, run_context, run_values):
"""
Triggered after every epoch.
It is called after every ``hooked_sess.run()`` call, and it
processes the values requested by the corresponding :meth:`before_run`.
It is equivalent to ``tf.train.SessionRunHook.after_run()``, refer to
TensorFlow docs for more details.
"""
pass
def trigger_step(self):
self._trigger_step()
def _trigger_step(self):
"""
Called after each :meth:`Trainer.run_step()` completes.
You can override it to implement, e.g. a ProgressBar.
"""
pass
def trigger_epoch(self):
self._trigger_epoch()
def _trigger_epoch(self):
"""
Called after the completion of every epoch.
"""
pass
def after_train(self):
"""
Called after training.
"""
self._after_train()
def _after_train(self):
"""
Called after training.
"""
pass
@property
......@@ -134,21 +159,26 @@ class Triggerable(Callback):
If an triggerable is used as a callback directly (instead of under other
higher-level wrapper to control the trigger), it will by default trigger after
every epoch. This is mainly for backward-compatibilty and convenience.
.. document private functions
.. automethod:: _trigger
.. automethod:: _trigger_epoch
"""
def trigger(self):
"""
Trigger something.
Note that this method may be called both inside an epoch and after an epoch.
"""
self._trigger()
@abstractmethod
def _trigger(self):
"""
Override this method to define what to trigger.
Note that this method may be called both inside an epoch and after an epoch.
"""
pass
def _trigger_epoch(self):
""" If used as a callback directly, run the trigger every epoch."""
""" If a :class:`Triggerable` is used as a callback directly,
the default behavior is to run the trigger every epoch."""
self.trigger()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment