Commit 0be707e9 authored by Yuxin Wu

make old GANs work with new trainer

parent 9a711e72
......@@ -8,22 +8,27 @@ If you want to do something different during training, first consider writing it
or write an issue to see if there is a better solution than creating new trainers.
If your task is fundamentally different from single-cost optimization, you may need to write a trainer.
Trainers just run __some__ iterations, so there is no limit on where the data comes from or what to do in an iteration.
The existing common trainers all implement two things:
1. Set up the graph and input pipeline, using the given `TrainConfig`.
2. Minimize `model.cost` in each iteration.
But you can customize it by using the base `Trainer` class.
* To customize the graph:
Add any tensors and ops you like, either before creating the trainer or inside `Trainer.__init__`.
In this case you don't need to set model/data in `TrainConfig` any more.
* Two ways to customize the iteration:
1. Set `Trainer.train_op`. This op will be run by default.
2. Subclass `Trainer` and override the `run_step()` method. This way you can do something more than running an op.
There are several different [GAN trainers](../../examples/GAN/GAN.py) for reference.
The implementation of [SimpleTrainer](../../tensorpack/train/simple.py) may also be helpful.
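For example, a minimal sketch of the second option above (the class name `MyAlternatingTrainer` and both ops are hypothetical placeholders; `hooked_sess` is the monitored session tensorpack trainers run ops with):

```python
from tensorpack.train import Trainer

class MyAlternatingTrainer(Trainer):
    """A sketch of a trainer whose iteration runs two ops instead of one."""
    def __init__(self, d_min, g_min):
        super(MyAlternatingTrainer, self).__init__()
        self._d_min = d_min  # e.g. an op minimizing a discriminator loss
        self._g_min = g_min  # e.g. an op minimizing a generator loss

    def run_step(self):
        # One iteration can do more than run a single train_op:
        # here it alternates between the two updates.
        self.hooked_sess.run(self._d_min)
        self.hooked_sess.run(self._g_min)
```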
Trainers are currently being redesigned, so the best way to customize a trainer will likely change.
We leave the tutorial empty for now.
<!--
-Trainers just run __some__ iterations, so there is no limit in where the data come from or what to do in an iteration.
-The existing common trainers all implement two things:
-1. Setup the graph and input pipeline, using the given `TrainConfig`.
-2. Minimize `model.cost` in each iteration.
-
-But you can customize it by using the base `Trainer` class.
-
-* To customize the graph:
-
- Add any tensors and ops you like, either before creating the trainer or inside `Trainer.__init__`.
- In this case you don't need to set model/data in `TrainConfig` any more.
-
-* Two ways to customize the iteration:
-
- 1. Set `Trainer.train_op`. This op will be run by default.
- 2. Subclass `Trainer` and override the `run_step()` method. This way you can do something more than running an op.
-
-There are several different [GAN trainers](../../examples/GAN/GAN.py) for reference.
-The implementation of [SimpleTrainer](../../tensorpack/train/simple.py) may also be helpful.
-->
......@@ -26,27 +26,30 @@ Most neural network training tasks are single-cost optimization.
Tensorpack provides some trainer implementations for such tasks.
These trainers build the graph from the given `ModelDesc` and minimize `ModelDesc.cost`.
To use trainers, pass a `TrainConfig` to configure them:
```python
config = TrainConfig(
    model=MyModel(),
    dataflow=my_dataflow,
    # data=my_inputsource,  # alternatively, use a customized InputSource
    callbacks=[...]
)

# start training:
SomeTrainer(config, other_arguments).train()

# start multi-GPU training with synchronous update:
# SyncMultiGPUTrainerParameterServer(config).train()
```
When you set the DataFlow (rather than the InputSource) in the config,
tensorpack trainers automatically adopt a certain prefetch mechanism, as mentioned
in the [Input Pipeline](input-source.html) tutorial.
You can set the InputSource instead to customize this behavior.
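For example, a minimal sketch that wraps the same DataFlow in an explicit InputSource (reusing `my_dataflow` and `MyModel` from the example above; `QueueInput` is one built-in InputSource, which prefetches data through a TF queue):

```python
from tensorpack import TrainConfig, QueueInput

config = TrainConfig(
    # data= replaces dataflow=: the trainer now uses exactly this
    # InputSource instead of choosing a prefetch scheme by itself.
    data=QueueInput(my_dataflow),
    model=MyModel(),
    callbacks=[...],
)
```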
<!--
-To use trainers, pass a `TrainConfig` to configure them:
-
-```python
-config = TrainConfig(
- model=MyModel()
- dataflow=my_dataflow,
- # data=my_inputsource, # alternatively, use a customized InputSource
- callbacks=[...]
- )
-
-# start training:
-SomeTrainer(config, other_arguments).train()
-
-# start multi-GPU training with synchronous update:
-# SyncMultiGPUTrainerParameterServer(config).train()
-```
-
-When you set the DataFlow (rather than the InputSource) in the config,
-tensorpack trainers automatically adopt certain prefetch mechanism, as mentioned
-in the [Input Pipeline](input-source.html) tutorial.
-You can set the InputSource instead, to customize this behavior.
-->
Trainers are being redesigned, so the recommended API will likely change soon.
Existing multi-GPU trainers include the logic of data-parallel training.
You can enable them with just one line, and all the necessary logic to achieve the best performance is already baked into the trainers.
......
......@@ -38,11 +38,41 @@ class Trainer(object):
    is_chief = True

-    def __init__(self):
    def __init__(self, config=None):
        """
        config is only for compatibility reasons in case you're
        using custom trainers with old-style API.
        You should never use config.
        """
        self._callbacks = []
        self.loop = TrainLoop()
        self._monitors = []  # Clarify the type. Don't change from list to monitors.

        # Hacks!
        if config is not None:
            logger.warn("You're initializing new trainer with old trainer API!")
            logger.warn("This could happen if you wrote a custom trainer before.")
            logger.warn("It may work now through some hacks, but please switch to the new API!")

            self._config = config
            self.inputs_desc = config.model.get_inputs_desc()
            self.tower_func = TowerFuncWrapper(
                lambda *inputs: config.model.build_graph(inputs),
                self.inputs_desc)
            self._main_tower_vs_name = ""

            def gp(input_names, output_names, tower=0):
                return TowerTrainer.get_predictor(self, input_names, output_names, device=tower)
            self.get_predictor = gp

            old_train = self.train

            def train():
                return old_train(
                    config.callbacks, config.monitors,
                    config.session_creator, config.session_init,
                    config.steps_per_epoch, config.starting_epoch, config.max_epoch)
            self.train = train

    def _register_callback(self, cb):
        """
        Register a callback to the trainer.
......@@ -192,7 +222,16 @@ class Trainer(object):
        if (len(args) > 0 and isinstance(args[0], old_train.TrainConfig)) \
                or 'config' in kwargs:
            name = cls.__name__
            try:
                old_trainer = getattr(old_train, name)
            except AttributeError:
                # custom trainer. has to live with it
                return super(Trainer, cls).__new__(cls)
            else:
                logger.warn("You're creating trainers with old trainer API!")
                logger.warn("Now it returns the old trainer for you, please switch to the new API!")
                logger.warn("'SomeTrainer(config, ...).train()' should be equivalent to "
                            "'launch_train_with_config(config, SomeTrainer(...))' in the new API.")
                return old_trainer(*args, **kwargs)
        else:
            return super(Trainer, cls).__new__(cls)
......
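As a hedged illustration of the two call styles this shim maps between (`SomeTrainer` and `other_args` are placeholders, following the warning message above):

```python
# Old-style call: intercepted by Trainer.__new__, which looks up the
# matching class in old_train, prints the warnings, and constructs it.
SomeTrainer(config, other_args).train()

# The equivalent new-style call named in the warning:
launch_train_with_config(config, SomeTrainer(other_args))
```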