Commit 5ae5f3e5 authored by Yuxin Wu

some docs about callbacks

parent 0630a31c
@@ -4,7 +4,7 @@
Apart from the actual training iterations that minimize the cost,
you almost surely would like to do something else during training.
Callbacks are such an interface to describe what to do besides the
training iterations.
There are several places where you might want to do something else:
@@ -14,9 +14,13 @@
* Between epochs (e.g. save the model, run some validation)
* After the training (e.g. send the model somewhere, send a message to your phone)
We found that people traditionally tend to write the training loop together with these extra features.
This makes the loop lengthy, and the code for the same feature gets scattered across it.
By writing callbacks to implement what you want to do at each of these places, tensorpack trainers
will call them at the proper time.
Therefore the code can be reused with a single line, as long as you are using tensorpack trainers.
For example, these are the callbacks I used when training a ResNet:
```python
TrainConfig(
@@ -24,8 +28,8 @@
  callbacks=[
    # save the model every epoch
    ModelSaver(),
    # backup the model with best validation error
    MinSaver('val-error-top1'),
    # run inference on another Dataflow every epoch, compute top1/top5 classification error and save them in log
    InferenceRunner(dataset_val, [
        ClassificationError('wrong-top1', 'val-error-top1'),
@@ -48,8 +52,8 @@
    ProgressBar(),
    # run `tf.summary.merge_all` every epoch and send results to monitors
    MergeAllSummaries(),
    # run ops in GraphKeys.UPDATE_OPS collection along with training, if any
    RunUpdateOps(),
  ],
  monitors=[  # monitors are a special kind of callback. these are also enabled by default
    # write all monitor data to tensorboard
...
## Write a callback
The places where each callback gets called are demonstrated in this snippet:
```python
def main_loop():
  # create graph for the model
  callbacks.setup_graph()
  # create session, initialize session, finalize graph ...
  # start training:
  callbacks.before_train()
  for epoch in range(epoch_start, epoch_end):
    for step in range(steps_per_epoch):
      run_step()  # callbacks.{before,after}_run are hooked with session
      callbacks.trigger_step()
    callbacks.trigger_epoch()
  callbacks.after_train()
```
You can override any of the following methods to define a new callback. A small sketch combining
several of them follows this list.
* `_setup_graph(self)`
  To separate "define" and "run", and to avoid the common mistake of creating ops inside
  loops, all changes to the graph should be made in this method. No session has been created at this point.
  TODO how to access the tensors already defined.
* `_before_train(self)`
  Can be used to run some manual initialization of variables, or to start some services for the whole training.
* `_trigger_step(self)`
  Do something (including running ops) after each step has finished.
  Be careful to only do light work here, because it could affect training speed.
* `_before_run(self, ctx)`, `_after_run(self, ctx, values)`
  These two are the equivalent of [tf.train.SessionRunHook](https://www.tensorflow.org/api_docs/python/tf/train/SessionRunHook).
  Please refer to the TensorFlow documentation for the detailed API.
  They are used to run extra ops / evaluate extra tensors / feed extra values __along with__ the actual training iteration.
  Note the difference between running __along with__ an iteration and running after an iteration.
  When you write

  ```python
  def _before_run(self, _):
      return tf.train.SessionRunArgs(fetches=my_op)
  ```

  the training loop would become `sess.run([training_op, my_op])`.
  This is different from `sess.run(training_op); sess.run(my_op);`,
  which is what you would get by running the op in `_trigger_step`.
* `_trigger_epoch(self)`
  Do something after each epoch has finished. Will call `self.trigger()` by default.
* `_trigger(self)`
  By default this will be called by `_trigger_epoch`,
  but you can customize the scheduling of this callback with
  `PeriodicTrigger`, to let this method run every k steps or every k epochs (see the second sketch below).
* `_after_train(self)`
  Do some finalization work.
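To make this concrete, here is a minimal sketch of a custom callback that puts several of these
methods together: it fetches an extra tensor along with every training iteration and logs its
average once per epoch. The class name, the tensor name and the choice of plain logging are made up
for illustration; only the overridden methods follow the interface described above.

```python
import tensorflow as tf
from tensorpack.callbacks import Callback
from tensorpack.utils import logger


class FetchExtraTensor(Callback):
    """ Fetch a tensor along with every training iteration and
        log its average once per epoch.
        The default tensor name here is hypothetical. """

    def __init__(self, tensor_name='tower0/my_loss:0'):
        self._tensor_name = tensor_name
        self._values = []

    def _setup_graph(self):
        # all graph access happens here; no session exists yet
        self._tensor = tf.get_default_graph().get_tensor_by_name(self._tensor_name)

    def _before_run(self, ctx):
        # fetch the tensor __along with__ this training iteration
        return tf.train.SessionRunArgs(fetches=self._tensor)

    def _after_run(self, ctx, run_values):
        # run_values.results holds whatever _before_run asked for
        self._values.append(run_values.results)

    def _trigger(self):
        # called once per epoch by default (through _trigger_epoch)
        if self._values:
            logger.info("epoch average of {}: {}".format(
                self._tensor_name, sum(self._values) / len(self._values)))
            self._values = []
```

Such a callback is then enabled by adding it to the `callbacks=[...]` list of `TrainConfig`,
like the built-in callbacks shown earlier.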
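And because the periodic work lives in `_trigger`, its schedule can be changed without touching the
callback itself. A hypothetical example, assuming `DumpSomething` is a callback that implements
`_trigger` only:

```python
from tensorpack.callbacks import PeriodicTrigger

# by default _trigger() runs once per epoch (through _trigger_epoch);
# wrapping the callback changes that schedule:
PeriodicTrigger(DumpSomething(), every_k_steps=500)   # run it every 500 steps
PeriodicTrigger(DumpSomething(), every_k_epochs=5)    # or every 5 epochs
```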
@@ -114,7 +114,7 @@ def get_config():
    callbacks=[
        ModelSaver(),
        PeriodicTrigger(
            RunOp(DQNModel.update_target_param, verbose=True),
            every_k_steps=10000 // UPDATE_FREQ),  # update target network every 10k steps
        expreplay,
        ScheduledHyperParamSetter('learning_rate',
...
@@ -106,7 +106,7 @@ class OnlinePredictor(PredictorBase):
            output_tensors (list): list of names.
            return_input (bool): same as :attr:`PredictorBase.return_input`.
            sess (tf.Session): the session this predictor runs in. If None,
                will use the default session at the first call.
        """
        self.return_input = return_input
        self.input_tensors = input_tensors
@@ -118,10 +118,8 @@ class OnlinePredictor(PredictorBase):
            "{} != {}".format(len(dp), len(self.input_tensors))
        feed = dict(zip(self.input_tensors, dp))
        if self.sess is None:
            self.sess = tf.get_default_session()
        output = self.sess.run(self.output_tensors, feed_dict=feed)
        return output
...
...@@ -62,13 +62,7 @@ class Trainer(object): ...@@ -62,13 +62,7 @@ class Trainer(object):
self.local_step = -1 self.local_step = -1
self._callbacks = [] self._callbacks = []
self.register_callback(MaintainStepCounter())
for cb in config.callbacks:
self.register_callback(cb)
self.monitors = [] self.monitors = []
for m in config.monitors:
self.register_monitor(m)
def register_callback(self, cb): def register_callback(self, cb):
""" """
...@@ -91,7 +85,7 @@ class Trainer(object): ...@@ -91,7 +85,7 @@ class Trainer(object):
assert not isinstance(self.monitors, Monitors), \ assert not isinstance(self.monitors, Monitors), \
"Cannot register more monitors after trainer was setup!" "Cannot register more monitors after trainer was setup!"
if not self.is_chief and mon.chief_only: if not self.is_chief and mon.chief_only:
logger.warn("Callback {} is chief-only, skipped.".format(str(mon))) logger.warn("Monitor {} is chief-only, skipped.".format(str(mon)))
else: else:
self.monitors.append(mon) self.monitors.append(mon)
self.register_callback(mon) self.register_callback(mon)
...@@ -115,10 +109,14 @@ class Trainer(object): ...@@ -115,10 +109,14 @@ class Trainer(object):
""" """
self._setup() # subclass will setup the graph self._setup() # subclass will setup the graph
self.register_callback(MaintainStepCounter())
for cb in self.config.callbacks:
self.register_callback(cb)
for m in self.config.monitors:
self.register_monitor(m)
self.monitors = Monitors(self.monitors) self.monitors = Monitors(self.monitors)
self.register_callback(self.monitors) self.register_callback(self.monitors)
# TODO cache per graph, avoid describing all towers
describe_model() describe_model()
# some final operations that might modify the graph # some final operations that might modify the graph
......