Commit 403815b5 authored by Yuxin Wu

update docs

parent f2d06a64
## Write a Callback

The time where each callback method gets called is demonstrated in this snippet:
```python
def train(self):
    # ... a predefined trainer may create graph for the model here ...
    # ... (the setup calls and the training loop are elided in this diff) ...
        callbacks.trigger_epoch()   # at the end of each epoch, inside the loop
    callbacks.after_train()
```
Note that at each place, the callbacks are called in the order they are given to the trainer.
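For a concrete picture, here is a minimal sketch of a custom callback (the class and what it counts are made up for illustration; the underscore-prefixed hooks are the overridable methods of tensorpack's `Callback` base class, matching the call sites in the snippet above):

```python
from tensorpack.callbacks import Callback

class StepCounter(Callback):
    """Toy example: count finished steps and report the total each epoch."""

    def _before_train(self):
        # runs once, right before the first epoch starts
        self._count = 0

    def _trigger_step(self):
        # runs after every step of the training loop
        self._count += 1

    def _trigger_epoch(self):
        # runs at the end of every epoch
        print('finished {} steps so far'.format(self._count))
```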
### Explain the Callback Methods
@@ -69,14 +69,14 @@ handle corner cases in noisy data, preprocess, etc.
`InputSource` is an abstract interface in tensorpack that describes where the inputs come from and how they enter the graph.
For example,
1. [FeedInput](http://tensorpack.readthedocs.io/en/latest/modules/input_source.html#tensorpack.input_source.FeedInput):
   comes from a DataFlow and is fed to the graph.
2. [QueueInput](http://tensorpack.readthedocs.io/en/latest/modules/input_source.html#tensorpack.input_source.QueueInput):
   comes from a DataFlow and is prefetched on CPU by a TF queue.
3. [StagingInput](http://tensorpack.readthedocs.io/en/latest/modules/input_source.html#tensorpack.input_source.StagingInput):
   comes from some `InputSource`, then is prefetched on GPU by a TF StagingArea.
4. Comes from a DataFlow, and is further processed by `tf.data.Dataset`.
5. [TensorInput](http://tensorpack.readthedocs.io/en/latest/modules/input_source.html#tensorpack.input_source.TensorInput):
   comes from some TF reading ops. (See the [PTB example](../../tensorpack/tree/master/examples/PennTreebank).)
6. Comes from some ZMQ pipe, where the load/preprocessing may happen on a different machine.
When you set `TrainConfig(dataflow=)`, tensorpack trainers automatically add proper prefetching for you.
You can also use the `TrainConfig(data=)` option to use a customized `InputSource`.
In case you want to use TF ops rather than a DataFlow, use `TensorInput` as the `InputSource`
(see the [PTB example](../../tensorpack/tree/master/examples/PennTreebank)).
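For example, a minimal sketch of plugging a customized `InputSource` into `TrainConfig` (the shapes and the `FakeData` stand-in are illustrative only):

```python
from tensorpack.dataflow import FakeData
from tensorpack.input_source import QueueInput
from tensorpack.train import TrainConfig

# FakeData stands in for whatever DataFlow you already have.
df = FakeData([[64, 224, 224, 3], [64]])   # fake image and label batches

config = TrainConfig(
    data=QueueInput(df),   # like dataflow=df, but with an explicit InputSource
    steps_per_epoch=100,
    # ... model, callbacks, etc. as usual ...
)
```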
@@ -13,7 +13,7 @@ Here's a list of things you can do when your training is slow:
3. If the GPU utilization is low, it may be because of slow data, or some ops are on CPU. Also make sure GPUs are not locked in P8 state.

## Benchmark the components
1. Use `DummyConstantInput(shapes)` as the `InputSource`,
   so that the iterations don't take any data from the Python side but train on a constant tensor.
   This will help find the slow operations in your graph.
2. Use `dataflow=FakeData(shapes, random=False)` to replace your original DataFlow with a constant DataFlow, as sketched below.
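A minimal sketch of the second method, assuming image/label inputs (the shapes are illustrative):

```python
from tensorpack.dataflow import FakeData

# Same shapes as the real inputs. With random=False every batch is the same
# constant data, so any remaining slowness comes from the graph, not the data.
fake_df = FakeData([[64, 224, 224, 3], [64]], random=False)
```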
@@ -10,12 +10,12 @@ This is how TensorFlow summaries eventually get logged/saved/printed:
1. __What to Log__: When you call `tf.summary.xxx` in your graph code, TensorFlow adds an op to
   the `tf.GraphKeys.SUMMARIES` collection (by default).
2. __When to Log__: The [MergeAllSummaries](../modules/callbacks.html#tensorpack.callbacks.MergeAllSummaries)
   callback is among the [default callbacks](http://tensorpack.readthedocs.io/en/latest/modules/train.html#tensorpack.train.DEFAULT_CALLBACKS).
   It runs ops in the `SUMMARIES` collection (by default) every epoch (by default),
   and writes the results to the monitors.
3. __Where to Log__:
   Several monitors are enabled as [default monitors](http://tensorpack.readthedocs.io/en/latest/modules/train.html#tensorpack.train.DEFAULT_MONITORS):
   * A [TFEventWriter](../modules/callbacks.html#tensorpack.callbacks.TFEventWriter)
     writes things to an event file used by tensorboard.
   * A [ScalarPrinter](../modules/callbacks.html#tensorpack.callbacks.ScalarPrinter)
@@ -23,7 +23,7 @@ This is how TensorFlow summaries eventually get logged/saved/printed:
   * A [JSONWriter](../modules/callbacks.html#tensorpack.callbacks.JSONWriter)
     saves scalars to a JSON file.

All the "what, when, where" can be customized in either the graph or with the callbacks/monitors setting.
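For example, a minimal sketch of customizing the monitors (keeping the event file and the console output, dropping the JSON file):

```python
from tensorpack.callbacks import ScalarPrinter, TFEventWriter
from tensorpack.train import TrainConfig

config = TrainConfig(
    monitors=[TFEventWriter(), ScalarPrinter()],   # no JSONWriter this time
    # ... the rest of the config as usual ...
)
```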
Since TF summaries are evaluated every epoch by default, if the content is data-dependent, the results
are likely to have too much variance. To address this issue, you can:
@@ -10,7 +10,8 @@ from ..utils import logger
from ..utils.argtools import call_only_once
from .common import get_tf_version_number, get_op_or_tensor_by_name, get_op_tensor_name
__all__ = ['get_current_tower_context', 'TowerContext', 'TowerFuncWrapper',
           'TowerTensorHandle', 'TowerTensorHandles']

_CurrentTowerContext = None
@@ -156,6 +157,9 @@ class TowerFuncWrapper(object):
A wrapper around a function which builds one tower (one replicate of the model).
It keeps track of the name scope, variable scope and input/output tensors
each time the function is called.

:class:`TowerTrainer` needs this option to be set, so that
it knows how to build a predictor.
"""

def __init__(self, tower_fn, inputs_desc):
@@ -189,6 +193,11 @@ class TowerFuncWrapper(object):
@property
def towers(self):
"""
Returns:
a :class:`TowerTensorHandles` object that can
access the tower handles by either indices or names.
"""
return TowerTensorHandles(self._handles)

@property
@@ -206,6 +215,13 @@ class TowerTensorHandles(object):
self._name_to_handle = {k.ns_name: k for k in handles}

def __getitem__(self, name_or_index):
"""
Args:
name_or_index (str or int):
Returns:
a :class:`TowerTensorHandle`.
"""
if isinstance(name_or_index, int):
    return self._handles[name_or_index]
return self._name_to_handle[name_or_index]
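To make the new docstrings concrete, here is a hedged usage sketch (`my_model_fn`, its inputs, and the tensor name `'logits'` are assumptions for illustration):

```python
from tensorpack.tfutils.tower import TowerContext, TowerFuncWrapper

tower_func = TowerFuncWrapper(my_model_fn, inputs_desc)   # my_model_fn is made up

with TowerContext('', is_training=True):
    tower_func(images, labels)        # builds one tower and records its tensors

handles = tower_func.towers           # a TowerTensorHandles
first = handles[0]                    # access by index, or by tower name
logits = first.get_tensor('logits')   # fetch a tensor recorded in that tower
```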
@@ -30,7 +30,9 @@ __all__ = ['TrainConfig', 'Trainer', 'DEFAULT_MONITORS', 'DEFAULT_CALLBACKS']
def DEFAULT_CALLBACKS():
"""
Return the default callbacks,
which will be used in :class:`TrainConfig` and :meth:`Trainer.train_with_defaults`.
They are:
1. MovingAverageSummary()
2. ProgressBar()
@@ -46,7 +48,9 @@ def DEFAULT_CALLBACKS():
def DEFAULT_MONITORS():
"""
Return the default monitors,
which will be used in :class:`TrainConfig` and :meth:`Trainer.train_with_defaults`.
They are:
1. TFEventWriter()
2. JSONWriter()
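As a sketch of how these defaults are typically consumed (`MyCallback` is a made-up placeholder; per the `TrainConfig` docs, `extra_callbacks` falls back to `DEFAULT_CALLBACKS()` when left unset):

```python
from tensorpack.train import DEFAULT_CALLBACKS, DEFAULT_MONITORS, TrainConfig

# The defaults are plain lists, so they can be inspected or extended.
config = TrainConfig(
    callbacks=[MyCallback()],              # MyCallback is a made-up placeholder
    extra_callbacks=DEFAULT_CALLBACKS(),   # what you get when it's left as None
    monitors=DEFAULT_MONITORS(),           # likewise the default when None
    # ... model, dataflow, etc. as usual ...
)
```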