Commit f85c3003 authored by Yuxin Wu's avatar Yuxin Wu

update documentation

parent b785bf77
......@@ -359,6 +359,7 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
'replace_get_variable',
'remap_get_variable',
'freeze_get_variable',
'Triggerable',
'ParamRestore']:
return True
if name in ['get_data', 'size', 'reset_state']:
......
## Write a callback
## Write a Callback
The places where each callback method gets called are demonstrated in this snippet:
......@@ -20,7 +20,7 @@ def main_loop():
callbacks.after_train()
```
### Explain the callback methods
### Explain the Callback Methods
You can override any of the following methods to define a new callback:
......
......@@ -17,8 +17,8 @@ To use trainers, pass a `TrainConfig` to configure them:
```python
config = TrainConfig(
model=MyModel()
dataflow=my_dataflow,
# data=my_inputsource, # alternatively, use a customized InputSource
dataflow=my_dataflow,
# data=my_inputsource, # alternatively, use a customized InputSource
callbacks=[...]
)
......@@ -45,4 +45,16 @@ would be multiplied by the number of GPUs.
### Custom Trainers
Trainers just run __some__ iterations, so there is no limit on where the data comes from or what to do in an iteration.
For example, [GAN trainer](../examples/GAN/GAN.py) minimizes two cost functions alternately.
The existing trainers implement the default logic, but you can implement them yourself by using the base `Trainer` class.
* Two ways to customize the graph:
1. Create the graph, add any tensors and ops before creating the trainer.
2. Subclass `Trainer` and override the `_setup()` method which will be called in `Trainer.__init__`.
* Two ways to customize the iteration:
1. Set `Trainer.train_op`. This op will be run by default.
2. Subclass `Trainer` and override the `run_step()` method.
There are several different [GAN trainers](../examples/GAN/GAN.py) for reference.
......@@ -13,7 +13,10 @@ __all__ = ['Callback', 'ProxyCallback', 'CallbackFactory', 'Triggerable']
@six.add_metaclass(ABCMeta)
class Callback(object):
""" Base class for all callbacks.
""" Base class for all callbacks. See
`Write a Callback
<http://tensorpack.readthedocs.io/en/latest/tutorial/extend/callback.html>`_
for more detailed explanation of the callback methods.
Attributes:
epoch_num(int): the number of the current epoch.
......@@ -261,7 +264,8 @@ class CallbackFactory(Callback):
"""
Each lambda takes ``self`` as the only argument.
trigger_epoch was deprecated.
Note:
trigger_epoch was deprecated.
"""
self._cb_setup_graph = setup_graph
......
......@@ -37,8 +37,6 @@ class Trainer(object):
sess (tf.Session): the current session in use.
hooked_sess (tf.MonitoredSession): the session with hooks.
monitors (Monitors): the monitors. Callbacks can use it for logging.
epoch_num (int): the number of epochs that have finished.
local_step (int): the number of steps that have finished in the current epoch.
"""
# step attr only available after before_train?
......@@ -64,6 +62,9 @@ class Trainer(object):
@property
def epoch_num(self):
"""
The number of epochs that have finished.
"""
if self._epoch_num is not None:
# has started training
return self._epoch_num
......
......@@ -18,10 +18,11 @@ from ..callbacks.graph import RunOp
from ..graph_builder.input_source import QueueInput, StagingInputWrapper, DummyConstantInput
from .feedfree import FeedfreeTrainerBase
__all__ = ['MultiGPUTrainerBase', 'SyncMultiGPUTrainer',
'AsyncMultiGPUTrainer', 'LeastLoadedDeviceSetter',
__all__ = ['MultiGPUTrainerBase', 'LeastLoadedDeviceSetter',
'SyncMultiGPUTrainerReplicated',
'SyncMultiGPUTrainerParameterServer']
'SyncMultiGPUTrainerParameterServer',
'AsyncMultiGPUTrainer',
'SyncMultiGPUTrainer']
def _check_tf_version():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment