Commit ba4e3178 authored by Yuxin Wu

[Trainerv2] Swap trainer directory. change two examples.

parent 9268bc8c
@@ -8,6 +8,14 @@ so you won't need to look at here very often.
Here is a list of things that were changed, starting from an early version.
TensorFlow itself also changed APIs before 1.0 and those are not listed here.
+ [2017/10/21]
tensorpack is gradually switching to a new Trainer API.
Compatibility is kept in most cases, but is not guaranteed.
To switch to the new API, the easiest way is to (see the sketch below):
1. `export TENSORPACK_TRAIN_API=v2` (will become the default in the future).
2. Replace `SomeTrainer(config, ...).train()` with `launch_train_with_config(config, SomeTrainer(...))`.
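A minimal before/after sketch of this migration (`MyModel` and `my_dataflow` are hypothetical placeholders, not part of this commit):

```python
import os
os.environ['TENSORPACK_TRAIN_API'] = 'v2'  # step 1: opt in before importing tensorpack

from tensorpack import *

# MyModel is assumed to be a ModelDesc subclass, my_dataflow a DataFlow.
config = TrainConfig(model=MyModel(), dataflow=my_dataflow)

# Old API:
#   SimpleTrainer(config).train()
# New API (step 2):
launch_train_with_config(config, SimpleTrainer())
```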
+ [2017/10/18]
`TrainConfig(predict_tower)` was deprecated. You can set the inference device directly when creating the `InferenceRunner` callback.
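A sketch of the replacement (`dataset_test` and the `ScalarStats` inferencer are placeholders; check the `InferenceRunner` docs for the exact argument name):

```python
# Instead of TrainConfig(predict_tower=[0], ...), pick the device on the callback:
callbacks = [
    InferenceRunner(dataset_test, [ScalarStats('cost')], device=0),  # inference on GPU 0
]
```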
+ [2017/10/12](https://github.com/ppwwyyxx/tensorpack/commit/7e963996f615b85f7459455596b4ee9bbd0bce8e).
......
@@ -22,9 +22,11 @@ In other words, an "epoch" in tensorpack is the __default period to run callbacks
### Common Trainers
Most neural network training tasks are single-cost optimization.
Tensorpack provides some trainer implementations for such tasks.
These trainers will build the graph based on the given `ModelDesc`, and minimize `ModelDesc.cost`.
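A minimal single-cost `ModelDesc` sketch under the API of this era (the network itself is a placeholder):

```python
import tensorflow as tf
from tensorpack import ModelDesc, InputDesc

class MyModel(ModelDesc):
    def _get_inputs(self):
        # declare the metadata of the input tensors
        return [InputDesc(tf.float32, (None, 28, 28), 'input'),
                InputDesc(tf.int32, (None,), 'label')]

    def _build_graph(self, inputs):
        image, label = inputs
        # placeholder network: a single linear layer
        flat = tf.reshape(image, [-1, 28 * 28])
        logits = tf.layers.dense(flat, 10)
        cost = tf.losses.sparse_softmax_cross_entropy(labels=label, logits=logits)
        self.cost = tf.identity(cost, name='cost')  # the cost the trainer minimizes

    def _get_optimizer(self):
        return tf.train.AdamOptimizer(1e-3)
```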
<!--
-Most neural network training tasks are single-cost optimization.
-Tensorpack provides some trainer implementations for such tasks.
-These trainers will build the graph based on the given `ModelDesc`, and minimizes `ModelDesc.cost`.
-->
<!--
-To use trainers, pass a `TrainConfig` to configure them:
@@ -49,7 +51,7 @@ These trainers will build the graph based on the given `ModelDesc`, and minimize
-in the [Input Pipeline](input-source.html) tutorial.
-You can set the InputSource instead, to customize this behavior.
-->
Trainers are being redesigned, so the recommended API will likely be changed soon.
Trainers are being redesigned; this page will be updated soon.
Existing multi-GPU trainers include the logic of data-parallel training.
You can enable them with just one line, and all the logic needed to achieve the best performance is already baked into the trainers. For example:
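A sketch under the new API, assuming `config` is an existing `TrainConfig`:

```python
# Data-parallel training over GPUs 0 and 1, enabled by the choice of trainer:
launch_train_with_config(config, SyncMultiGPUTrainerParameterServer([0, 1]))
```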
......
@@ -2,12 +2,13 @@
# -*- coding: UTF-8 -*-
# File: cifar-convnet.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
-from tensorpack import *
import tensorflow as tf
import argparse
import numpy as np
import os
+os.environ['TENSORPACK_TRAIN_API'] = 'v2'  # will become default soon
+from tensorpack import *
import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.tfutils.summary import *
from tensorpack.dataflow import dataset
@@ -151,8 +152,7 @@ if __name__ == '__main__':
    if args.load:
        config.session_init = SaverRestore(args.load)
-    config.nr_tower = max(len(args.gpu.split(',')), 1)
-    if config.nr_tower <= 1:
-        QueueInputTrainer(config).train()
-    else:
-        SyncMultiGPUTrainerParameterServer(config).train()
+    nr_gpu = len(args.gpu.split(','))
+    trainer = QueueInputTrainer() if nr_gpu <= 1 \
+        else SyncMultiGPUTrainerParameterServer(list(range(nr_gpu)))
+    launch_train_with_config(config, trainer)
@@ -12,6 +12,7 @@ MNIST ConvNet example.
about 0.6% validation error after 30 epochs.
"""
+os.environ['TENSORPACK_TRAIN_API'] = 'v2'  # will become default soon
# Just import everything into current namespace
from tensorpack import *
from tensorpack.tfutils import summary
@@ -142,4 +143,4 @@ if __name__ == '__main__':
        config.session_init = SaverRestore(args.load)
    # SimpleTrainer is slow; this is just a demo.
    # You can use QueueInputTrainer instead.
-    SimpleTrainer(config).train()
+    launch_train_with_config(config, SimpleTrainer())
[flake8]
max-line-length = 120
-ignore = F403,F401,F405,F841,E401
+ignore = F403,F401,F405,F841,E401,E402
exclude = private,
          FasterRCNN/utils
@@ -18,9 +18,9 @@ if _HAS_TF:
    # In development. Default to v1
    if _os.environ.get('TENSORPACK_TRAIN_API', 'v1') == 'v2':
        from tensorpack.trainv2 import *
-    else:
-        from tensorpack.train import *
+    else:
+        from tensorpack.trainv1 import *
    from tensorpack.graph_builder import InputDesc, ModelDesc, ModelDescBase
    from tensorpack.input_source import *
    from tensorpack.predict import *
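Since this switch happens at import time, the environment variable must be set before tensorpack is imported, which is exactly what the updated examples above do:

```python
import os
os.environ['TENSORPACK_TRAIN_API'] = 'v2'  # must run before the tensorpack import
from tensorpack import *  # now resolves trainers from the v2 package
```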
@@ -7,11 +7,11 @@ import tensorflow as tf
from ..input_source import (
    InputSource, FeedInput, QueueInput, StagingInputWrapper, DummyConstantInput)
-from ..train.config import TrainConfig
+from ..trainv1.config import TrainConfig
from .base import SingleCostTrainer
from .trainers import SimpleTrainer, DistributedTrainerReplicated

-__all__ = ['launch_train_with_config', 'TrainConfig', 'apply_default_prefetch']
+__all__ = ['launch_train_with_config', 'apply_default_prefetch']
def apply_default_prefetch(input_source_or_dataflow, trainer, towers):
......
@@ -24,6 +24,7 @@ from .base import SingleCostTrainer
__all__ = ['SimpleTrainer',
           'QueueInputTrainer',
           'SyncMultiGPUTrainer',
+           'SyncMultiGPUTrainerReplicated',
           'SyncMultiGPUTrainerParameterServer',
           'AsyncMultiGPUTrainer',
@@ -68,6 +69,17 @@ class SyncMultiGPUTrainerParameterServer(SingleCostTrainer):
        return []


def SyncMultiGPUTrainer(towers):
    """
    Return a default multi-GPU trainer, if you don't care about the details.
    It may not be the most efficient one for your task.

    Args:
        towers (list[int]): list of GPU ids.
    """
    return SyncMultiGPUTrainerParameterServer(towers, ps_device='gpu')
class AsyncMultiGPUTrainer(SingleCostTrainer):

    __doc__ = AsyncMultiGPUBuilder.__doc__
......
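A usage sketch for the `SyncMultiGPUTrainer` convenience function added above (`config` is assumed to be an existing `TrainConfig`):

```python
# Synchronous data-parallel training on GPUs 0 and 1;
# gradients are averaged and variables live on GPU (ps_device='gpu').
trainer = SyncMultiGPUTrainer([0, 1])
launch_train_with_config(config, trainer)
```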
@@ -19,7 +19,7 @@ def global_import(name):
_CURR_DIR = os.path.dirname(__file__)
-_SKIP = []
+_SKIP = ['utility']
for _, module_name, _ in iter_modules(
        [_CURR_DIR]):
    srcpath = os.path.join(_CURR_DIR, module_name + '.py')
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: interface.py

__all__ = ['launch_train_with_config']


def launch_train_with_config(config, trainer):
    # Thin forwarding shim: keeps the old import path working by
    # delegating to the implementation in ..train.interface.
    from ..train.interface import launch_train_with_config as old_launch
    old_launch(config, trainer)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: utility.py

# for backwards-compatibility
from ..graph_builder.utils import (  # noqa
    OverrideToLocalVariable,
    override_to_local_variable, LeastLoadedDeviceSetter)