Commit 403815b5 authored by Yuxin Wu

update docs

parent f2d06a64
## Write a Callback
In the main loop of the trainer,
the callbacks will be called in the order they are given in `TrainConfig`.
When each callback method gets called is demonstrated in this snippet:
```python
def train(self):
    # ... a predefined trainer may create graph for the model here ...
    # ... (lines elided in this diff) ...
    callbacks.trigger_epoch()
    callbacks.after_train()
```
Note that at each place, the callbacks are called in the order they were given to the trainer.
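The dispatch pattern above can be modeled in a few lines of plain Python (a simplified stand-in, not tensorpack's actual implementation): each hook simply loops over the callbacks in registration order.

```python
class Callback:
    """Base class with no-op hooks, standing in for tensorpack's Callback."""
    def before_train(self): pass
    def trigger_epoch(self): pass
    def after_train(self): pass

class CallbackGroup:
    """Dispatches every hook to each callback, preserving the order
    in which the callbacks were given."""
    def __init__(self, cbs):
        self.cbs = list(cbs)

    def before_train(self):
        for cb in self.cbs:
            cb.before_train()

    def trigger_epoch(self):
        for cb in self.cbs:
            cb.trigger_epoch()

    def after_train(self):
        for cb in self.cbs:
            cb.after_train()

class Recorder(Callback):
    """Records which of its hooks fired, to make the order visible."""
    def __init__(self, name, log):
        self.name, self.log = name, log
    def before_train(self):
        self.log.append((self.name, "before_train"))
    def trigger_epoch(self):
        self.log.append((self.name, "trigger_epoch"))

log = []
group = CallbackGroup([Recorder("a", log), Recorder("b", log)])
group.before_train()
group.trigger_epoch()
group.after_train()
print(log)  # at each hook, "a" fires before "b"
```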
### Explain the Callback Methods
......
@@ -69,14 +69,14 @@ handle corner cases in noisy data, preprocess, etc.
`InputSource` is an abstract interface in tensorpack that describes where the inputs come from and how they enter the graph.
For example:
1. [FeedInput](http://tensorpack.readthedocs.io/en/latest/modules/input_source.html#tensorpack.input_source.FeedInput):
Comes from a DataFlow and is fed into the graph.
2. [QueueInput](http://tensorpack.readthedocs.io/en/latest/modules/input_source.html#tensorpack.input_source.QueueInput):
Comes from a DataFlow and is prefetched on CPU by a TF queue.
3. [StagingInput](http://tensorpack.readthedocs.io/en/latest/modules/input_source.html#tensorpack.input_source.StagingInput):
Comes from some `InputSource`, then prefetched on GPU by a TF StagingArea.
4. Comes from a DataFlow and is further processed by `tf.data.Dataset`.
5. [TensorInput](http://tensorpack.readthedocs.io/en/latest/modules/input_source.html#tensorpack.input_source.TensorInput):
Comes from some TF reading ops. (See the [PTB example](../../tensorpack/tree/master/examples/PennTreebank).)
6. Comes from some ZMQ pipe, where the loading/preprocessing may happen on a different machine.
When you set `TrainConfig(dataflow=)`, tensorpack trainers automatically add proper prefetching for you.
You can also use the `TrainConfig(data=)` option to use a customized `InputSource`.
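To see why CPU-side prefetching helps, here is roughly what `QueueInput` boils down to, sketched with Python's stdlib instead of a TF queue (a conceptual stand-in, not tensorpack's implementation): a producer thread fills a bounded queue while the consumer runs the training step.

```python
import queue
import threading

def dataflow():
    # stand-in for a tensorpack DataFlow: a generator of datapoints
    for i in range(5):
        yield i * i

def prefetch(df_fn, capacity=2):
    """Run a dataflow in a background thread, buffering datapoints in a
    bounded queue so the consumer rarely waits for data generation."""
    q = queue.Queue(maxsize=capacity)
    _END = object()  # sentinel marking exhaustion of the dataflow

    def producer():
        for dp in df_fn():
            q.put(dp)  # blocks when the buffer is full (backpressure)
        q.put(_END)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        dp = q.get()
        if dp is _END:
            return
        yield dp

result = list(prefetch(dataflow))
print(result)  # same datapoints, produced ahead of time in the background
```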
@@ -13,7 +13,7 @@ Here's a list of things you can do when your training is slow:
3. If the GPU utilization is low, it may be because the data is slow or because some ops run on CPU. Also make sure GPUs are not locked in the P8 state.
## Benchmark the components
1. Use `DummyConstantInput(shapes)` as the `InputSource`,
so that the iterations don't take any data from the Python side but train on a constant tensor.
This will help find out the slow operations you're using in the graph.
2. Use `dataflow=FakeData(shapes, random=False)` to replace your original DataFlow with a constant DataFlow.
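The idea behind this benchmarking step can be sketched in plain Python (`ConstantDataFlow` below is a hypothetical stand-in, not tensorpack's `FakeData`): generate one datapoint up front and yield it forever, so the data side costs almost nothing and any remaining slowness must come from the graph.

```python
import random

class ConstantDataFlow:
    """Yields the same pre-generated datapoint forever. Like
    FakeData(shapes, random=False), this makes per-step data cost
    near zero, isolating graph/session time when benchmarking."""
    def __init__(self, shapes, seed=0):
        rng = random.Random(seed)
        # generate one flat list of floats per requested shape, once
        self._dp = [[rng.random() for _ in range(self._size(s))]
                    for s in shapes]

    @staticmethod
    def _size(shape):
        n = 1
        for d in shape:
            n *= d
        return n

    def __iter__(self):
        while True:
            yield self._dp  # the same object every time: no per-step work

df = ConstantDataFlow(shapes=[(2, 3), (4,)])
it = iter(df)
first, second = next(it), next(it)
print(first is second)  # the identical datapoint is reused
```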
......
@@ -10,12 +10,12 @@ This is how TensorFlow summaries eventually get logged/saved/printed:
1. __What to Log__: When you call `tf.summary.xxx` in your graph code, TensorFlow adds an op to
`tf.GraphKeys.SUMMARIES` collection (by default).
2. __When to Log__: The [MergeAllSummaries](../modules/callbacks.html#tensorpack.callbacks.MergeAllSummaries)
callback is among the [default callbacks](http://tensorpack.readthedocs.io/en/latest/modules/train.html#tensorpack.train.DEFAULT_CALLBACKS).
It runs ops in the `SUMMARIES` collection (by default) every epoch (by default),
and writes results to the monitors.
3. __Where to Log__:
Several monitors are enabled as [default monitors](http://tensorpack.readthedocs.io/en/latest/modules/train.html#tensorpack.train.DEFAULT_MONITORS):
* A [TFEventWriter](../modules/callbacks.html#tensorpack.callbacks.TFEventWriter)
writes everything to an event file used by TensorBoard.
* A [ScalarPrinter](../modules/callbacks.html#tensorpack.callbacks.ScalarPrinter)
@@ -23,7 +23,7 @@ This is how TensorFlow summaries eventually get logged/saved/printed:
* A [JSONWriter](../modules/callbacks.html#tensorpack.callbacks.JSONWriter)
saves scalars to a JSON file.
All of the "what, when, where" can be customized either in the graph or with the callbacks/monitors settings.
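The three stages can be modeled as a toy in plain Python (hypothetical names, not tensorpack's classes): a collection of summary thunks (what), a per-epoch trigger (when), and a list of monitors that each receive the results (where).

```python
SUMMARIES = []  # "what": the collection of registered summary thunks

def add_summary(name, fn):
    SUMMARIES.append((name, fn))

class ListWriter:
    """A stand-in for a monitor such as TFEventWriter or JSONWriter:
    simply records every scalar it receives."""
    def __init__(self):
        self.records = []
    def put_scalar(self, name, value):
        self.records.append((name, value))

def trigger_epoch(monitors):
    # "when": evaluate every registered summary once per epoch,
    # then fan the results out to every monitor ("where").
    for name, fn in SUMMARIES:
        value = fn()
        for m in monitors:
            m.put_scalar(name, value)

add_summary("loss", lambda: 0.25)
writer = ListWriter()
trigger_epoch([writer])
print(writer.records)
```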
Since TF summaries are evaluated every epoch by default, if the content is data-dependent, the results
are likely to have too much variance. To address this issue, you can:
......
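One standard variance-reduction remedy, which is also the idea behind the `MovingAverageSummary` default callback, is to report a moving average instead of the raw per-step value; sketched here in plain Python rather than TF ops:

```python
def ema(values, decay=0.9):
    """Exponential moving average: each reported point mixes the new
    value into the running average, damping step-to-step noise."""
    avg = None
    out = []
    for v in values:
        avg = v if avg is None else decay * avg + (1 - decay) * v
        out.append(avg)
    return out

# a noisy, data-dependent scalar that alternates wildly between steps
noisy = [1.0, 3.0, 1.0, 3.0, 1.0, 3.0]
smooth = ema(noisy, decay=0.5)
print(smooth)  # converges toward the mean (2.0) instead of oscillating
```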
@@ -10,7 +10,8 @@ from ..utils import logger
from ..utils.argtools import call_only_once
from .common import get_tf_version_number, get_op_or_tensor_by_name, get_op_tensor_name
__all__ = ['get_current_tower_context', 'TowerContext', 'TowerFuncWrapper',
'TowerTensorHandle', 'TowerTensorHandles']
_CurrentTowerContext = None
@@ -156,6 +157,9 @@ class TowerFuncWrapper(object):
A wrapper around a function which builds one tower (one replicate of the model).
It keeps track of the name scope, variable scope and input/output tensors
each time the function is called.
:class:`TowerTrainer` needs this so that
it knows how to build a predictor.
"""
def __init__(self, tower_fn, inputs_desc):
@@ -189,6 +193,11 @@ class TowerFuncWrapper(object):
@property
def towers(self):
"""
Returns:
a :class:`TowerTensorHandles` object, which can
access the tower handles by either indices or names.
"""
return TowerTensorHandles(self._handles)
@property
@@ -206,6 +215,13 @@ class TowerTensorHandles(object):
self._name_to_handle = {k.ns_name: k for k in handles}
def __getitem__(self, name_or_index):
"""
Args:
name_or_index (str or int):
Returns:
a :class:`TowerTensorHandle`.
"""
if isinstance(name_or_index, int):
return self._handles[name_or_index]
return self._name_to_handle[name_or_index]
......
@@ -30,7 +30,9 @@ __all__ = ['TrainConfig', 'Trainer', 'DEFAULT_MONITORS', 'DEFAULT_CALLBACKS']
def DEFAULT_CALLBACKS():
"""
Return the default callbacks,
which will be used in :class:`TrainConfig` and :meth:`Trainer.train_with_defaults`.
They are:
1. MovingAverageSummary()
2. ProgressBar()
@@ -46,7 +48,9 @@ def DEFAULT_CALLBACKS():
def DEFAULT_MONITORS():
"""
Return the default monitors,
which will be used in :class:`TrainConfig` and :meth:`Trainer.train_with_defaults`.
They are:
1. TFEventWriter()
2. JSONWriter()
......