Commit ba4e3178 authored by Yuxin Wu

[Trainerv2] Swap trainer directory. Change two examples.

parent 9268bc8c
@@ -8,6 +8,14 @@ so you won't need to look at here very often.
 Here are a list of things that were changed, starting from an early version.
 TensorFlow itself also changed APIs before 1.0 and those are not listed here.
++ [2017/10/21]
+tensorpack is gradually switching to a new Trainer API.
+Compatibility is kept in most ways but not guaranteed.
+To switch to the new API, the easiest way is to:
+1. `export TENSORPACK_TRAIN_API=v2` (will be the default in the future).
+2. Replace `SomeTrainer(config, ...).train()` with `launch_train_with_config(config, SomeTrainer(...))`.
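For illustration, a minimal before/after sketch of step 2 (assuming tensorpack is installed; the model and dataflow are placeholders, and the `TrainConfig` is built exactly as before):

```python
import os
os.environ['TENSORPACK_TRAIN_API'] = 'v2'  # must be set before importing tensorpack
from tensorpack import *

# config = TrainConfig(model=MyModel(), dataflow=df, ...)  # unchanged

# Old (v1) API: the trainer consumed the config and ran itself:
#   SyncMultiGPUTrainerParameterServer(config).train()

# New (v2) API: construct the trainer with a list of GPU ids,
# then hand it the config:
#   launch_train_with_config(config, SyncMultiGPUTrainerParameterServer([0, 1]))
```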
 + [2017/10/18]
 `TrainConfig(predict_tower)` was deprecated. You can set the inference device directly when creating the `InferenceRunner` callback.
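A hedged sketch of that entry (the exact keyword name, `device`, is an assumption based on the wording above; `dataset_test` is a placeholder DataFlow):

```python
from tensorpack import *

# Instead of TrainConfig(predict_tower=[0]), give the device to the callback:
# callbacks = [InferenceRunner(dataset_test, [ScalarStats('cost')], device=0)]
```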
 + [2017/10/12](https://github.com/ppwwyyxx/tensorpack/commit/7e963996f615b85f7459455596b4ee9bbd0bce8e).
@@ -22,9 +22,11 @@ In other words, an "epoch" in tensorpack is the __default period to run callback
 ### Common Trainers
-Most neural network training tasks are single-cost optimization.
-Tensorpack provides some trainer implementations for such tasks.
-These trainers will build the graph based on the given `ModelDesc`, and minimizes `ModelDesc.cost`.
+<!--
+-Most neural network training tasks are single-cost optimization.
+-Tensorpack provides some trainer implementations for such tasks.
+-These trainers will build the graph based on the given `ModelDesc`, and minimizes `ModelDesc.cost`.
+-->
 <!--
 -To use trainers, pass a `TrainConfig` to configure them:
@@ -49,7 +51,7 @@ These trainers will build the graph based on the given `ModelDesc`, and minimize
 -in the [Input Pipeline](input-source.html) tutorial.
 -You can set the InputSource instead, to customize this behavior.
 -->
-Trainers are being redesigned, so the recommended API will likely be changed soon.
+Trainers are being redesigned; this page will be updated soon.
 Existing multi-GPU trainers include the logic of data-parallel training.
 You can enable them by just one line, and all the necessary logic to achieve the best performance was baked into the trainers already.
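As a hedged illustration of that one-line switch under this commit's new API (the `TrainConfig` construction is elided; the GPU ids are examples):

```python
from tensorpack import *

# config = TrainConfig(...)  # model, dataflow, callbacks, etc.
# Single GPU, queue-based input:
#   launch_train_with_config(config, QueueInputTrainer())
# Data-parallel on two GPUs: only the trainer line changes.
#   launch_train_with_config(config, SyncMultiGPUTrainerParameterServer([0, 1]))
```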
@@ -2,12 +2,13 @@
 # -*- coding: UTF-8 -*-
 # File: cifar-convnet.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
-from tensorpack import *
 import tensorflow as tf
 import argparse
 import numpy as np
 import os
+os.environ['TENSORPACK_TRAIN_API'] = 'v2'   # will become default soon
+from tensorpack import *
 import tensorpack.tfutils.symbolic_functions as symbf
 from tensorpack.tfutils.summary import *
 from tensorpack.dataflow import dataset
@@ -151,8 +152,7 @@ if __name__ == '__main__':
     if args.load:
         config.session_init = SaverRestore(args.load)

-    config.nr_tower = max(len(args.gpu.split(',')), 1)
-    if config.nr_tower <= 1:
-        QueueInputTrainer(config).train()
-    else:
-        SyncMultiGPUTrainerParameterServer(config).train()
+    nr_gpu = len(args.gpu.split(','))
+    trainer = QueueInputTrainer() if nr_gpu <= 1 \
+        else SyncMultiGPUTrainerParameterServer(list(range(nr_gpu)))
+    launch_train_with_config(config, trainer)
@@ -12,6 +12,7 @@ MNIST ConvNet example.
 about 0.6% validation error after 30 epochs.
 """
+os.environ['TENSORPACK_TRAIN_API'] = 'v2'   # will become default soon
 # Just import everything into current namespace
 from tensorpack import *
 from tensorpack.tfutils import summary
@@ -142,4 +143,4 @@ if __name__ == '__main__':
         config.session_init = SaverRestore(args.load)
     # SimpleTrainer is slow, this is just a demo.
     # You can use QueueInputTrainer instead
-    SimpleTrainer(config).train()
+    launch_train_with_config(config, SimpleTrainer())
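Spelling out the swap that comment suggests, under the new API (a sketch; `config` is the TrainConfig built in the example):

```python
from tensorpack import *

# config = TrainConfig(...)  # as in the example above
# Queue-based input instead of the slow feed-based SimpleTrainer:
# launch_train_with_config(config, QueueInputTrainer())
```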
 [flake8]
 max-line-length = 120
-ignore = F403,F401,F405,F841,E401
+ignore = F403,F401,F405,F841,E401,E402
 exclude = private,
     FasterRCNN/utils
@@ -18,9 +18,9 @@ if _HAS_TF:
     # In development. Default to v1
     if _os.environ.get('TENSORPACK_TRAIN_API', 'v1') == 'v2':
-        from tensorpack.trainv2 import *
-    else:
         from tensorpack.train import *
+    else:
+        from tensorpack.trainv1 import *

     from tensorpack.graph_builder import InputDesc, ModelDesc, ModelDescBase
     from tensorpack.input_source import *
     from tensorpack.predict import *
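Note that the switch is read once at import time, so it must be set before the first tensorpack import; both examples in this commit follow that order. A minimal sketch:

```python
import os
# Must be set before tensorpack is imported; the value is read at import time.
os.environ['TENSORPACK_TRAIN_API'] = 'v2'

from tensorpack import *  # with v2, this re-exports the new tensorpack.train API
```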
[This diff is collapsed.]
@@ -7,11 +7,11 @@ import tensorflow as tf
 from ..input_source import (
     InputSource, FeedInput, QueueInput, StagingInputWrapper, DummyConstantInput)
-from ..train.config import TrainConfig
+from ..trainv1.config import TrainConfig
 from .base import SingleCostTrainer
 from .trainers import SimpleTrainer, DistributedTrainerReplicated

-__all__ = ['launch_train_with_config', 'TrainConfig', 'apply_default_prefetch']
+__all__ = ['launch_train_with_config', 'apply_default_prefetch']


 def apply_default_prefetch(input_source_or_dataflow, trainer, towers):
@@ -24,6 +24,7 @@ from .base import SingleCostTrainer
 __all__ = ['SimpleTrainer',
            'QueueInputTrainer',
+           'SyncMultiGPUTrainer',
            'SyncMultiGPUTrainerReplicated',
            'SyncMultiGPUTrainerParameterServer',
            'AsyncMultiGPUTrainer',
@@ -68,6 +69,17 @@ class SyncMultiGPUTrainerParameterServer(SingleCostTrainer):
         return []


+def SyncMultiGPUTrainer(towers):
+    """
+    Return a default multi-GPU trainer, if you don't care about the details.
+    It may not be the most efficient one for your task.
+
+    Args:
+        towers (list[int]): list of GPU ids.
+    """
+    return SyncMultiGPUTrainerParameterServer(towers, ps_device='gpu')
+
+
 class AsyncMultiGPUTrainer(SingleCostTrainer):
     __doc__ = AsyncMultiGPUBuilder.__doc__
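A hedged usage sketch of the `SyncMultiGPUTrainer` helper added above (the `TrainConfig` construction is elided; GPU ids are examples):

```python
from tensorpack import *

# config = TrainConfig(...)
# The helper is shorthand for the parameter-server strategy with
# variables placed on the GPUs:
# launch_train_with_config(config, SyncMultiGPUTrainer([0, 1]))
# ...equivalent to:
# launch_train_with_config(
#     config, SyncMultiGPUTrainerParameterServer([0, 1], ps_device='gpu'))
```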
@@ -19,7 +19,7 @@ def global_import(name):

 _CURR_DIR = os.path.dirname(__file__)
-_SKIP = []
+_SKIP = ['utility']
 for _, module_name, _ in iter_modules(
         [_CURR_DIR]):
     srcpath = os.path.join(_CURR_DIR, module_name + '.py')
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# File: interface.py
+
+__all__ = ['launch_train_with_config']
+
+
+def launch_train_with_config(config, trainer):
+    from ..train.interface import launch_train_with_config as old_launch
+    old_launch(config, trainer)
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# File: utility.py
+
+# for backwards-compatibility
+from ..graph_builder.utils import (  # noqa
+    OverrideToLocalVariable,
+    override_to_local_variable, LeastLoadedDeviceSetter)