Commit a867fa57 authored by Yuxin Wu's avatar Yuxin Wu

remove Trainer.train_with_config

parent bfad96d7
......@@ -50,7 +50,7 @@ These trainers will build the graph by itself, with the following arguments:
3. A function which takes input tensors and returns the cost.
4. A function which returns an optimizer.
See [SingleCostTrainer.setup_graph](http://localhost:8000/modules/train.html#tensorpack.train.SingleCostTrainer.setup_graph)
See [SingleCostTrainer.setup_graph](http://tensorpack.readthedocs.io/en/latest/modules/train.html#tensorpack.train.SingleCostTrainer.setup_graph)
for details.
Existing multi-GPU trainers include the logic of data-parallel training.
......
......@@ -5,10 +5,22 @@ Tensorpack trainers have an interface for maximum flexibility.
There are also interfaces built on top of trainers to simplify the use,
when you don't want to customize too much.
### Raw Trainer Interface
For general trainer, build the graph by yourself.
For single-cost trainer, build the graph by
[SingleCostTrainer.setup_graph](http://tensorpack.readthedocs.io/en/latest/modules/train.html#tensorpack.train.SingleCostTrainer.setup_graph).
Then, call
[Trainer.train()](http://tensorpack.readthedocs.io/en/latest/modules/train.html#tensorpack.train.Trainer.train)
or
[Trainer.train_with_defaults()](http://tensorpack.readthedocs.io/en/latest/modules/train.html#tensorpack.train.Trainer.train_with_defaults)
which applies some defaults options for normal use cases.
### With ModelDesc and TrainConfig
[SingleCost trainers](trainer.html#single-cost-trainers)
expects 4 arguments to build the graph: `InputDesc`, `InputSource`, get_cost function, and optimizer.
expects 4 arguments in `setup_graph`: `InputDesc`, `InputSource`, get_cost function, and an optimizer.
`ModelDesc` describes a model by packing 3 of them together into one object:
```python
......
......@@ -25,7 +25,7 @@ import tensorpack.trainv1 as old_train # noqa
from ..trainv1.base import StopTraining, TrainLoop
from ..trainv1.config import TrainConfig
__all__ = ['TrainConfig', 'Trainer']
__all__ = ['TrainConfig', 'Trainer', 'DEFAULT_MONITORS', 'DEFAULT_CALLBACKS']
def DEFAULT_CALLBACKS():
......@@ -243,28 +243,6 @@ class Trainer(object):
self.initialize(session_creator, session_init)
self.main_loop(steps_per_epoch, starting_epoch, max_epoch)
def train_with_config(self, config):
"""
An alias to simplify the use of `TrainConfig` with `Trainer`.
This method is literally the following:
.. code-block:: python
self.train(
config.callbacks, config.monitors,
config.session_creator, config.session_init,
config.steps_per_epoch, config.starting_epoch, config.max_epoch)
"""
if config.data or config.dataflow or config.model:
logger.warn(
"data/dataflow/model in TrainConfig will not be used "
"in `Trainer.train_with_config`")
logger.warn("To build the graph from config, use `launch_train_with_config`!")
self.train(
config.callbacks, config.monitors,
config.session_creator, config.session_init,
config.steps_per_epoch, config.starting_epoch, config.max_epoch)
def train_with_defaults(
self, callbacks=None, monitors=None,
session_creator=None, session_init=None,
......
......@@ -86,5 +86,7 @@ def launch_train_with_config(config, trainer):
trainer.setup_graph(
inputs_desc, input,
model._build_graph_get_cost, model.get_optimizer)
config.data = config.dataflow = config.model = None
trainer.train_with_config(config)
trainer.train(
config.callbacks, config.monitors,
config.session_creator, config.session_init,
config.steps_per_epoch, config.starting_epoch, config.max_epoch)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment