Commit fabb7e7e authored by Yuxin Wu

update docs about callbacks

parent 2c129ded
# Callbacks

Callback is an interface to do __everything else__ besides the training iterations.
Apart from the actual training iterations that minimize the cost,
you almost surely would like to do something else.

There are several places where you might want to do something else:

* Before the training has started (e.g. initialize the saver, dump the graph)
@@ -8,15 +8,16 @@ def train(self):

```python
def train(self):
    # ...
    callbacks.setup_graph()
    # ... create session, initialize session, finalize graph ...
    # start training:
    with sess.as_default():
        callbacks.before_train()
        for epoch in range(epoch_start, epoch_end):
            callbacks.before_epoch()
            for step in range(steps_per_epoch):
                self.run_step()  # callbacks.{before,after}_run are hooked with session
                callbacks.trigger_step()
            callbacks.after_epoch()
            callbacks.trigger_epoch()
        callbacks.after_train()
```

Note that at each place, each callback will be called in the order they are given to the trainer.
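
For concreteness, here is a minimal sketch of how callbacks are given to the trainer, assuming the common `TrainConfig` + `launch_train_with_config` setup; `MyModel` and `my_dataflow` are placeholders for your own model and input pipeline:

```python
# A minimal sketch, not the only way to launch training.
# MyModel and my_dataflow are placeholders for your own ModelDesc and DataFlow.
from tensorpack import TrainConfig, SimpleTrainer, launch_train_with_config
from tensorpack.callbacks import ModelSaver, EstimatedTimeLeft

config = TrainConfig(
    model=MyModel(),
    dataflow=my_dataflow,
    callbacks=[
        ModelSaver(),         # runs first at each trigger point
        EstimatedTimeLeft(),  # runs second
    ],
)
launch_train_with_config(config, SimpleTrainer())
```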

@@ -39,12 +40,12 @@ you can use TF methods such as

If you're using a `TowerTrainer` instance, more tools are available:

* Use `self.trainer.tower_func.towers` to access the
  [tower handles](../modules/tfutils.html#tensorpack.tfutils.tower.TowerTensorHandles),
  and therefore the tensors in each tower.
* [self.get_tensors_maybe_in_tower()](../modules/callbacks.html#tensorpack.callbacks.Callback.get_tensors_maybe_in_tower)
  is a helper function to access tensors in the first training tower.
* [self.trainer.get_predictor()](../modules/train.html#tensorpack.train.TowerTrainer.get_predictor)
  is a helper function to create a callable under inference mode (see the sketch below).
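
As an illustration, here is a hedged sketch of a callback built on `get_predictor`; the tensor names `'input'` and `'output'`, the callback name, and the monitored key are assumptions for this example, not part of the API:

```python
import numpy as np

from tensorpack.callbacks import Callback


class InferenceOnFixedBatch(Callback):
    """Hypothetical example: run an inference-mode predictor on a fixed batch."""

    def __init__(self, batch):
        self._batch = batch  # a numpy array matching the assumed 'input' placeholder

    def _setup_graph(self):
        # 'input' and 'output' are assumed tensor names in your model
        self._pred = self.trainer.get_predictor(['input'], ['output'])

    def _trigger(self):
        output = self._pred(self._batch)[0]
        self.trainer.monitors.put_scalar('fixed_batch/mean_output', float(np.mean(output)))
```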
* `_before_train(self)` * `_before_train(self)`

@@ -98,8 +99,21 @@ to let this method run every k steps or every k epochs.

* Access tensors / ops in either training / inference mode (need to create them in `_setup_graph`).
* Write stuff to the monitor backend, by `self.trainer.monitors.put_xxx`.
  The monitors might direct your events to TensorFlow events file, JSON file, stdout, etc.
  You can get history monitor data as well. See the docs for [Monitors](../../modules/callbacks.html#tensorpack.callbacks.Monitors).
* Access the current status of training, such as `epoch_num`, `global_step`. See [here](../../modules/callbacks.html#tensorpack.callbacks.Callback).
* Stop training by `raise StopTraining()` (with `from tensorpack.train import StopTraining`); see the sketch below.
* Anything else that can be done with plain Python.
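
Putting a few of these together, a hedged sketch (the callback name and the monitored key are invented for illustration):

```python
from tensorpack.callbacks import Callback
from tensorpack.train import StopTraining


class StopAfterEpoch(Callback):
    """Hypothetical example: log training status, then stop at a chosen epoch."""

    def __init__(self, last_epoch):
        self._last_epoch = last_epoch

    def _trigger(self):
        # write to the monitor backend; monitors may forward this to the
        # TensorFlow events file, a JSON file, stdout, etc.
        self.trainer.monitors.put_scalar('custom/global_step', self.global_step)
        if self.epoch_num >= self._last_epoch:
            raise StopTraining()
```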

### Typical Steps for Writing/Using a Callback

* Define the callback in `__init__`; prepare for it in `_setup_graph` and `_before_train`.
* Know whether you want to do something __along with__ the session run or not.
  If yes, implement the logic with `_{before,after}_run`.
  Otherwise, implement it in `_trigger` or `_trigger_step`.
* You can choose to only implement "what to do", and leave "when to do" to
  other wrappers such as
  [PeriodicTrigger](../../modules/callbacks.html#tensorpack.callbacks.PeriodicTrigger),
  [PeriodicRunHooks](../../modules/callbacks.html#tensorpack.callbacks.PeriodicRunHooks),
  or [EnableCallbackIf](../../modules/callbacks.html#tensorpack.callbacks.EnableCallbackIf),
  as in the sketch below.
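
To make these steps concrete, below is a hedged end-to-end sketch: the callback implements only "what to do" (fetch an extra tensor along with each session run), while "when to do" is delegated to `PeriodicRunHooks`; the callback name and the tensor name `'total_cost'` are assumptions about your graph:

```python
from tensorpack.callbacks import Callback, PeriodicRunHooks


class FetchScalar(Callback):
    """Hypothetical example: fetch a scalar tensor together with each run_step."""

    def __init__(self, tensor_name):
        self._tensor_name = tensor_name

    def _setup_graph(self):
        # locate the tensor in the first training tower
        self._tensor = self.get_tensors_maybe_in_tower([self._tensor_name])[0]

    def _before_run(self, ctx):
        # tensors returned here are fetched together with the training step
        return [self._tensor]

    def _after_run(self, run_context, run_values):
        self.trainer.monitors.put_scalar(self._tensor_name, run_values.results[0])


# "what to do" is FetchScalar; "when to do" (every 100 steps) is the wrapper's job:
callback = PeriodicRunHooks(FetchScalar('total_cost'), every_k_steps=100)
```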