Commit 4c926eb7 authored by Yuxin Wu's avatar Yuxin Wu

docs of callbacks (#147)

parent f3644ce9
...@@ -8,6 +8,9 @@ so you won't need to look at here very often. ...@@ -8,6 +8,9 @@ so you won't need to look at here very often.
Here are a list of things that were changed, starting from an early version. Here are a list of things that were changed, starting from an early version.
TensorFlow itself also changes API and those are not listed here. TensorFlow itself also changes API and those are not listed here.
+ 2017/02/20. The interface of step callbacks are changed to be the same as `tf.train.SessionRunHook`.
If you haven't written any custom step callbacks, there is nothing to do. Otherwise please refer
to the [existing callbacks](https://github.com/ppwwyyxx/tensorpack/blob/master/tensorpack/callbacks/steps.py).
+ 2017/02/12. `TrainConfig(optimizer=)` was deprecated. Now optimizer is set in `ModelDesc`. And + 2017/02/12. `TrainConfig(optimizer=)` was deprecated. Now optimizer is set in `ModelDesc`. And
gradient processors become part of an optimizer. See [commit](https://github.com/ppwwyyxx/tensorpack/commit/d1041a77a9c59d8c9abf64f389f3b605d65b483e). gradient processors become part of an optimizer. See [commit](https://github.com/ppwwyyxx/tensorpack/commit/d1041a77a9c59d8c9abf64f389f3b605d65b483e).
* 2017/02/11. `_get_input_vars()` in `ModelDesc` was renamed to `_get_inputs`. `InputVar` was * 2017/02/11. `_get_input_vars()` in `ModelDesc` was renamed to `_get_inputs`. `InputVar` was
......
...@@ -3,5 +3,5 @@ tensorpack.callbacks package ...@@ -3,5 +3,5 @@ tensorpack.callbacks package
.. automodule:: tensorpack.callbacks .. automodule:: tensorpack.callbacks
:members: :members:
:undoc-members: :no-undoc-members:
:show-inheritance: :show-inheritance:
...@@ -15,9 +15,7 @@ class Callback(object): ...@@ -15,9 +15,7 @@ class Callback(object):
""" Base class for all callbacks """ Base class for all callbacks
Attributes: Attributes:
epoch_num(int): the current epoch num, starting from 1. epoch_num(int): the number of the current epoch.
local_step(int): the current local step number (1-based) in the current epoch.
which is also the number of steps that have finished.
global_step(int): the number of global steps that have finished. global_step(int): the number of global steps that have finished.
trainer(Trainer): the trainer. trainer(Trainer): the trainer.
graph(tf.Graph): the graph. graph(tf.Graph): the graph.
...@@ -25,16 +23,18 @@ class Callback(object): ...@@ -25,16 +23,18 @@ class Callback(object):
Note: Note:
These attributes are available only after (and including) These attributes are available only after (and including)
:meth:`_setup_graph`. :meth:`_setup_graph`.
"""
def setup_graph(self, trainer): .. document private functions
.. automethod:: _setup_graph
.. automethod:: _before_train
.. automethod:: _before_run
.. automethod:: _after_run
.. automethod:: _trigger_step
.. automethod:: _trigger_epoch
.. automethod:: _after_train
""" """
Called before finalizing the graph.
Use this callback to setup some ops used in the callback.
Args: def setup_graph(self, trainer):
trainer(Trainer): the trainer which calls the callback
"""
self._steps_per_epoch = trainer.config.steps_per_epoch self._steps_per_epoch = trainer.config.steps_per_epoch
self.trainer = trainer self.trainer = trainer
self.graph = tf.get_default_graph() self.graph = tf.get_default_graph()
...@@ -42,37 +42,30 @@ class Callback(object): ...@@ -42,37 +42,30 @@ class Callback(object):
self._setup_graph() self._setup_graph()
def _setup_graph(self): def _setup_graph(self):
"""
Called before finalizing the graph.
Override this method to setup the ops used in the callback.
This is the same as ``tf.train.SessionRunHook.begin()``.
"""
pass pass
def before_train(self): def before_train(self):
"""
Called right before the first iteration.
"""
self._starting_step = get_global_step_value() self._starting_step = get_global_step_value()
self._before_train() self._before_train()
def _before_train(self): def _before_train(self):
pass
def trigger_step(self):
""" """
Callback to be triggered after every run_step. Called right before the first iteration. The main difference to
""" `setup_graph` is that at this point the graph is finalized and a
self._trigger_step() default session is initialized.
Override this method to, e.g. run some operations under the session.
def _trigger_step(self): This is similar to ``tf.train.SessionRunHook.after_create_session()``, but different:
pass it is called after the session is initialized by :class:`tfutils.SessionInit`.
"""
def after_run(self, run_context, run_values):
self._after_run(run_context, run_values)
def _after_run(self, run_context, run_values):
pass pass
def before_run(self, ctx): def before_run(self, ctx):
"""
Same as ``tf.train.SessionRunHook.before_run``.
"""
fetches = self._before_run(ctx) fetches = self._before_run(ctx)
if fetches is None: if fetches is None:
return None return None
...@@ -91,24 +84,56 @@ class Callback(object): ...@@ -91,24 +84,56 @@ class Callback(object):
return tf.train.SessionRunArgs(fetches=ret) return tf.train.SessionRunArgs(fetches=ret)
def _before_run(self, ctx): def _before_run(self, ctx):
"""
It is called before every ``hooked_sess.run()`` call, and it
registers some extra op/tensors to run in the next call.
This method is the same as ``tf.train.SessionRunHook.before_run``.
Refer to TensorFlow docs for more details.
An extra feature is that you can also simply return a list of names,
instead of a ``tf.train.SessionRunArgs``.
"""
return None return None
def trigger_epoch(self): def after_run(self, run_context, run_values):
self._after_run(run_context, run_values)
def _after_run(self, run_context, run_values):
""" """
Triggered after every epoch. It is called after every ``hooked_sess.run()`` call, and it
processes the values requested by the corresponding :meth:`before_run`.
It is equivalent to ``tf.train.SessionRunHook.after_run()``, refer to
TensorFlow docs for more details.
""" """
pass
def trigger_step(self):
self._trigger_step()
def _trigger_step(self):
"""
Called after each :meth:`Trainer.run_step()` completes.
You can override it to implement, e.g. a ProgressBar.
"""
pass
def trigger_epoch(self):
self._trigger_epoch() self._trigger_epoch()
def _trigger_epoch(self): def _trigger_epoch(self):
"""
Called after the completion of every epoch.
"""
pass pass
def after_train(self): def after_train(self):
"""
Called after training.
"""
self._after_train() self._after_train()
def _after_train(self): def _after_train(self):
"""
Called after training.
"""
pass pass
@property @property
...@@ -134,21 +159,26 @@ class Triggerable(Callback): ...@@ -134,21 +159,26 @@ class Triggerable(Callback):
If an triggerable is used as a callback directly (instead of under other If an triggerable is used as a callback directly (instead of under other
higher-level wrapper to control the trigger), it will by default trigger after higher-level wrapper to control the trigger), it will by default trigger after
every epoch. This is mainly for backward-compatibilty and convenience. every epoch. This is mainly for backward-compatibilty and convenience.
.. document private functions
.. automethod:: _trigger
.. automethod:: _trigger_epoch
""" """
def trigger(self): def trigger(self):
"""
Trigger something.
Note that this method may be called both inside an epoch and after an epoch.
"""
self._trigger() self._trigger()
@abstractmethod @abstractmethod
def _trigger(self): def _trigger(self):
"""
Override this method to define what to trigger.
Note that this method may be called both inside an epoch and after an epoch.
"""
pass pass
def _trigger_epoch(self): def _trigger_epoch(self):
""" If used as a callback directly, run the trigger every epoch.""" """ If a :class:`Triggerable` is used as a callback directly,
the default behavior is to run the trigger every epoch."""
self.trigger() self.trigger()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment