Commit a867fa57 authored by Yuxin Wu's avatar Yuxin Wu

remove Trainer.train_with_config

parent bfad96d7
...@@ -50,7 +50,7 @@ These trainers will build the graph by itself, with the following arguments: ...@@ -50,7 +50,7 @@ These trainers will build the graph by itself, with the following arguments:
3. A function which takes input tensors and returns the cost. 3. A function which takes input tensors and returns the cost.
4. A function which returns an optimizer. 4. A function which returns an optimizer.
See [SingleCostTrainer.setup_graph](http://localhost:8000/modules/train.html#tensorpack.train.SingleCostTrainer.setup_graph) See [SingleCostTrainer.setup_graph](http://tensorpack.readthedocs.io/en/latest/modules/train.html#tensorpack.train.SingleCostTrainer.setup_graph)
for details. for details.
Existing multi-GPU trainers include the logic of data-parallel training. Existing multi-GPU trainers include the logic of data-parallel training.
......
...@@ -5,10 +5,22 @@ Tensorpack trainers have an interface for maximum flexibility. ...@@ -5,10 +5,22 @@ Tensorpack trainers have an interface for maximum flexibility.
There are also interfaces built on top of trainers to simplify the use, There are also interfaces built on top of trainers to simplify the use,
when you don't want to customize too much. when you don't want to customize too much.
### Raw Trainer Interface
For general trainer, build the graph by yourself.
For single-cost trainer, build the graph by
[SingleCostTrainer.setup_graph](http://tensorpack.readthedocs.io/en/latest/modules/train.html#tensorpack.train.SingleCostTrainer.setup_graph).
Then, call
[Trainer.train()](http://tensorpack.readthedocs.io/en/latest/modules/train.html#tensorpack.train.Trainer.train)
or
[Trainer.train_with_defaults()](http://tensorpack.readthedocs.io/en/latest/modules/train.html#tensorpack.train.Trainer.train_with_defaults)
which applies some defaults options for normal use cases.
### With ModelDesc and TrainConfig ### With ModelDesc and TrainConfig
[SingleCost trainers](trainer.html#single-cost-trainers) [SingleCost trainers](trainer.html#single-cost-trainers)
expects 4 arguments to build the graph: `InputDesc`, `InputSource`, get_cost function, and optimizer. expects 4 arguments in `setup_graph`: `InputDesc`, `InputSource`, get_cost function, and an optimizer.
`ModelDesc` describes a model by packing 3 of them together into one object: `ModelDesc` describes a model by packing 3 of them together into one object:
```python ```python
......
...@@ -25,7 +25,7 @@ import tensorpack.trainv1 as old_train # noqa ...@@ -25,7 +25,7 @@ import tensorpack.trainv1 as old_train # noqa
from ..trainv1.base import StopTraining, TrainLoop from ..trainv1.base import StopTraining, TrainLoop
from ..trainv1.config import TrainConfig from ..trainv1.config import TrainConfig
__all__ = ['TrainConfig', 'Trainer'] __all__ = ['TrainConfig', 'Trainer', 'DEFAULT_MONITORS', 'DEFAULT_CALLBACKS']
def DEFAULT_CALLBACKS(): def DEFAULT_CALLBACKS():
...@@ -243,28 +243,6 @@ class Trainer(object): ...@@ -243,28 +243,6 @@ class Trainer(object):
self.initialize(session_creator, session_init) self.initialize(session_creator, session_init)
self.main_loop(steps_per_epoch, starting_epoch, max_epoch) self.main_loop(steps_per_epoch, starting_epoch, max_epoch)
def train_with_config(self, config):
"""
An alias to simplify the use of `TrainConfig` with `Trainer`.
This method is literally the following:
.. code-block:: python
self.train(
config.callbacks, config.monitors,
config.session_creator, config.session_init,
config.steps_per_epoch, config.starting_epoch, config.max_epoch)
"""
if config.data or config.dataflow or config.model:
logger.warn(
"data/dataflow/model in TrainConfig will not be used "
"in `Trainer.train_with_config`")
logger.warn("To build the graph from config, use `launch_train_with_config`!")
self.train(
config.callbacks, config.monitors,
config.session_creator, config.session_init,
config.steps_per_epoch, config.starting_epoch, config.max_epoch)
def train_with_defaults( def train_with_defaults(
self, callbacks=None, monitors=None, self, callbacks=None, monitors=None,
session_creator=None, session_init=None, session_creator=None, session_init=None,
......
...@@ -86,5 +86,7 @@ def launch_train_with_config(config, trainer): ...@@ -86,5 +86,7 @@ def launch_train_with_config(config, trainer):
trainer.setup_graph( trainer.setup_graph(
inputs_desc, input, inputs_desc, input,
model._build_graph_get_cost, model.get_optimizer) model._build_graph_get_cost, model.get_optimizer)
config.data = config.dataflow = config.model = None trainer.train(
trainer.train_with_config(config) config.callbacks, config.monitors,
config.session_creator, config.session_init,
config.steps_per_epoch, config.starting_epoch, config.max_epoch)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment