Commit f85c3003 authored by Yuxin Wu's avatar Yuxin Wu

update documentation

parent b785bf77
...@@ -359,6 +359,7 @@ def autodoc_skip_member(app, what, name, obj, skip, options): ...@@ -359,6 +359,7 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
'replace_get_variable', 'replace_get_variable',
'remap_get_variable', 'remap_get_variable',
'freeze_get_variable', 'freeze_get_variable',
'Triggerable',
'ParamRestore']: 'ParamRestore']:
return True return True
if name in ['get_data', 'size', 'reset_state']: if name in ['get_data', 'size', 'reset_state']:
......
## Write a callback ## Write a Callback
The places where each callback method gets called is demonstrated in this snippet: The places where each callback method gets called is demonstrated in this snippet:
...@@ -20,7 +20,7 @@ def main_loop(): ...@@ -20,7 +20,7 @@ def main_loop():
callbacks.after_train() callbacks.after_train()
``` ```
### Explain the callback methods ### Explain the Callback Methods
You can override any of the following methods to define a new callback: You can override any of the following methods to define a new callback:
......
...@@ -17,8 +17,8 @@ To use trainers, pass a `TrainConfig` to configure them: ...@@ -17,8 +17,8 @@ To use trainers, pass a `TrainConfig` to configure them:
```python ```python
config = TrainConfig( config = TrainConfig(
model=MyModel() model=MyModel()
dataflow=my_dataflow, dataflow=my_dataflow,
# data=my_inputsource, # alternatively, use a customized InputSource # data=my_inputsource, # alternatively, use a customized InputSource
callbacks=[...] callbacks=[...]
) )
...@@ -45,4 +45,16 @@ would be multiplied by the number of GPUs. ...@@ -45,4 +45,16 @@ would be multiplied by the number of GPUs.
### Custom Trainers ### Custom Trainers
Trainers just run __some__ iterations, so there is no limit in where the data come from or what to do in an iteration. Trainers just run __some__ iterations, so there is no limit in where the data come from or what to do in an iteration.
For example, [GAN trainer](../examples/GAN/GAN.py) minimizes two cost functions alternatively. The existing trainers implement the default logic, but you can implement them yourself by using the base `Trainer` class.
* Two ways to customize the graph:
1. Create the graph, add any tensors and ops before creating the trainer.
2. Subclass `Trainer` and override the `_setup()` method which will be called in `Trainer.__init__`.
* Two ways to customize the iteration:
1. Set `Trainer.train_op`. This op will be run by default.
2. Subclass `Trainer` and override the `run_step()` method.
There are several different [GAN trainers](../examples/GAN/GAN.py) for reference.
...@@ -13,7 +13,10 @@ __all__ = ['Callback', 'ProxyCallback', 'CallbackFactory', 'Triggerable'] ...@@ -13,7 +13,10 @@ __all__ = ['Callback', 'ProxyCallback', 'CallbackFactory', 'Triggerable']
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
class Callback(object): class Callback(object):
""" Base class for all callbacks. """ Base class for all callbacks. See
`Write a Callback
<http://tensorpack.readthedocs.io/en/latest/tutorial/extend/callback.html>`_
for more detailed explanation of the callback methods.
Attributes: Attributes:
epoch_num(int): the number of the current epoch. epoch_num(int): the number of the current epoch.
...@@ -261,7 +264,8 @@ class CallbackFactory(Callback): ...@@ -261,7 +264,8 @@ class CallbackFactory(Callback):
""" """
Each lambda takes ``self`` as the only argument. Each lambda takes ``self`` as the only argument.
trigger_epoch was deprecated. Note:
trigger_epoch was deprecated.
""" """
self._cb_setup_graph = setup_graph self._cb_setup_graph = setup_graph
......
...@@ -37,8 +37,6 @@ class Trainer(object): ...@@ -37,8 +37,6 @@ class Trainer(object):
sess (tf.Session): the current session in use. sess (tf.Session): the current session in use.
hooked_sess (tf.MonitoredSession): the session with hooks. hooked_sess (tf.MonitoredSession): the session with hooks.
monitors (Monitors): the monitors. Callbacks can use it for logging. monitors (Monitors): the monitors. Callbacks can use it for logging.
epoch_num (int): the number of epochs that have finished.
local_step (int): the number of steps that have finished in the current epoch. local_step (int): the number of steps that have finished in the current epoch.
""" """
# step attr only available after before_train? # step attr only available after before_train?
...@@ -64,6 +62,9 @@ class Trainer(object): ...@@ -64,6 +62,9 @@ class Trainer(object):
@property @property
def epoch_num(self): def epoch_num(self):
"""
The number of epochs that have finished.
"""
if self._epoch_num is not None: if self._epoch_num is not None:
# has started training # has started training
return self._epoch_num return self._epoch_num
......
...@@ -18,10 +18,11 @@ from ..callbacks.graph import RunOp ...@@ -18,10 +18,11 @@ from ..callbacks.graph import RunOp
from ..graph_builder.input_source import QueueInput, StagingInputWrapper, DummyConstantInput from ..graph_builder.input_source import QueueInput, StagingInputWrapper, DummyConstantInput
from .feedfree import FeedfreeTrainerBase from .feedfree import FeedfreeTrainerBase
__all__ = ['MultiGPUTrainerBase', 'SyncMultiGPUTrainer', __all__ = ['MultiGPUTrainerBase', 'LeastLoadedDeviceSetter',
'AsyncMultiGPUTrainer', 'LeastLoadedDeviceSetter',
'SyncMultiGPUTrainerReplicated', 'SyncMultiGPUTrainerReplicated',
'SyncMultiGPUTrainerParameterServer'] 'SyncMultiGPUTrainerParameterServer',
'AsyncMultiGPUTrainer',
'SyncMultiGPUTrainer']
def _check_tf_version(): def _check_tf_version():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment