Commit 17a73a4c authored by Yuxin Wu's avatar Yuxin Wu

Update more examples to use Trainerv2

parent e63fe50f
...@@ -9,6 +9,7 @@ import numpy as np ...@@ -9,6 +9,7 @@ import numpy as np
import os import os
import tensorflow as tf import tensorflow as tf
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.symbolic_functions import * from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import * from tensorpack.tfutils.summary import *
...@@ -192,6 +193,6 @@ if __name__ == '__main__': ...@@ -192,6 +193,6 @@ if __name__ == '__main__':
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
if args.gpu: if args.gpu:
config.nr_tower = len(args.gpu.split(',')) nr_tower = len(args.gpu.split(','))
assert config.nr_tower == NR_GPU assert nr_tower == NR_GPU
SyncMultiGPUTrainer(config).train() launch_train_with_config(config, SyncMultiGPUTrainer(NR_GPU))
...@@ -10,6 +10,7 @@ import os ...@@ -10,6 +10,7 @@ import os
import tensorflow as tf import tensorflow as tf
import multiprocessing import multiprocessing
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.symbolic_functions import * from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import * from tensorpack.tfutils.summary import *
...@@ -298,5 +299,4 @@ if __name__ == '__main__': ...@@ -298,5 +299,4 @@ if __name__ == '__main__':
config = get_config() config = get_config()
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
config.nr_tower = NR_GPU launch_train_with_config(config, SyncMultiGPUTrainer(NR_GPU))
SyncMultiGPUTrainer(config).train()
...@@ -7,6 +7,7 @@ import numpy as np ...@@ -7,6 +7,7 @@ import numpy as np
import os import os
import argparse import argparse
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.gradproc import * from tensorpack.tfutils.gradproc import *
from tensorpack.tfutils import optimizer, summary from tensorpack.tfutils import optimizer, summary
...@@ -174,4 +175,4 @@ if __name__ == '__main__': ...@@ -174,4 +175,4 @@ if __name__ == '__main__':
config = get_config() config = get_config()
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
SimpleTrainer(config).train() launch_train_with_config(config, SimpleTrainer())
...@@ -7,6 +7,7 @@ import numpy as np ...@@ -7,6 +7,7 @@ import numpy as np
import argparse import argparse
import os import os
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.symbolic_functions import * from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import * from tensorpack.tfutils.summary import *
...@@ -171,7 +172,7 @@ if __name__ == '__main__': ...@@ -171,7 +172,7 @@ if __name__ == '__main__':
[(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0002)]) [(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0002)])
], ],
max_epoch=400, max_epoch=400,
nr_tower=max(get_nr_gpu(), 1),
session_init=SaverRestore(args.load) if args.load else None session_init=SaverRestore(args.load) if args.load else None
) )
SyncMultiGPUTrainerParameterServer(config).train() nr_gpu = max(get_nr_gpu(), 1)
launch_train_with_config(config, SyncMultiGPUTrainerParameterServer(nr_gpu))
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: svhn-resnet.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import argparse
import numpy as np
import os
from tensorpack import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.dataflow import dataset
from tensorpack.utils.gpu import get_nr_gpu
import tensorflow as tf
"""
ResNet-110 for SVHN Digit Classification.
Reach 1.8% validation error after 70 epochs, with 2 TitanX. 2it/s.
You might need to adjust the learning rate schedule when running with 1 GPU.
"""
import imp
cifar_example = imp.load_source('cifar_example',
os.path.join(os.path.dirname(__file__), 'cifar10-resnet.py'))
Model = cifar_example.Model
BATCH_SIZE = 128
def get_data(train_or_test):
isTrain = train_or_test == 'train'
pp_mean = dataset.SVHNDigit.get_per_pixel_mean()
if isTrain:
d1 = dataset.SVHNDigit('train')
d2 = dataset.SVHNDigit('extra')
ds = RandomMixData([d1, d2])
else:
ds = dataset.SVHNDigit('test')
if isTrain:
augmentors = [
imgaug.CenterPaste((40, 40)),
imgaug.Brightness(10),
imgaug.Contrast((0.8, 1.2)),
imgaug.GaussianDeform( # this is slow. without it, can only reach 1.9% error
[(0.2, 0.2), (0.2, 0.8), (0.8, 0.8), (0.8, 0.2)],
(40, 40), 0.2, 3),
imgaug.RandomCrop((32, 32)),
imgaug.MapImage(lambda x: x - pp_mean),
]
else:
augmentors = [
imgaug.MapImage(lambda x: x - pp_mean)
]
ds = AugmentImageComponent(ds, augmentors)
ds = BatchData(ds, 128, remainder=not isTrain)
if isTrain:
ds = PrefetchData(ds, 5, 5)
return ds
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--load', help='load model')
args = parser.parse_args()
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
logger.auto_set_dir()
dataset_train = get_data('train')
dataset_test = get_data('test')
config = TrainConfig(
model=Model(n=18),
dataflow=dataset_train,
callbacks=[
ModelSaver(),
InferenceRunner(dataset_test,
[ScalarStats('cost'), ClassificationError()]),
ScheduledHyperParamSetter('learning_rate',
[(1, 0.1), (20, 0.01), (28, 0.001), (50, 0.0001)])
],
nr_tower=max(get_nr_gpu(), 1),
session_init=SaverRestore(args.load) if args.load else None,
max_epoch=500,
)
SyncMultiGPUTrainerParameterServer(config).train()
...@@ -9,6 +9,7 @@ import numpy as np ...@@ -9,6 +9,7 @@ import numpy as np
import os import os
import multiprocessing import multiprocessing
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
import tensorflow as tf import tensorflow as tf
from tensorflow.contrib.layers import variance_scaling_initializer from tensorflow.contrib.layers import variance_scaling_initializer
from tensorpack import * from tensorpack import *
...@@ -19,9 +20,10 @@ from tensorpack.tfutils.summary import * ...@@ -19,9 +20,10 @@ from tensorpack.tfutils.summary import *
from tensorpack.utils.gpu import get_nr_gpu from tensorpack.utils.gpu import get_nr_gpu
from tensorpack.utils import viz from tensorpack.utils import viz
from imagenet_resnet_utils import ( from imagenet_utils import (
fbresnet_augmentor, preresnet_basicblock, preresnet_group, fbresnet_augmentor, image_preprocess, compute_loss_and_error)
image_preprocess, compute_loss_and_error) from resnet_model import (
preresnet_basicblock, preresnet_group)
TOTAL_BATCH_SIZE = 256 TOTAL_BATCH_SIZE = 256
...@@ -90,10 +92,6 @@ def get_data(train_or_test): ...@@ -90,10 +92,6 @@ def get_data(train_or_test):
def get_config(): def get_config():
nr_gpu = get_nr_gpu()
global BATCH_SIZE
BATCH_SIZE = TOTAL_BATCH_SIZE // nr_gpu
dataset_train = get_data('train') dataset_train = get_data('train')
dataset_val = get_data('val') dataset_val = get_data('val')
...@@ -111,7 +109,6 @@ def get_config(): ...@@ -111,7 +109,6 @@ def get_config():
], ],
steps_per_epoch=5000, steps_per_epoch=5000,
max_epoch=105, max_epoch=105,
nr_tower=nr_gpu
) )
...@@ -163,6 +160,9 @@ if __name__ == '__main__': ...@@ -163,6 +160,9 @@ if __name__ == '__main__':
if args.gpu: if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
nr_gpu = get_nr_gpu()
BATCH_SIZE = TOTAL_BATCH_SIZE // nr_gpu
if args.cam: if args.cam:
BATCH_SIZE = 128 # something that can run on one gpu BATCH_SIZE = 128 # something that can run on one gpu
viz_cam(args.load, args.data) viz_cam(args.load, args.data)
...@@ -172,4 +172,4 @@ if __name__ == '__main__': ...@@ -172,4 +172,4 @@ if __name__ == '__main__':
config = get_config() config = get_config()
if args.load: if args.load:
config.session_init = get_model_loader(args.load) config.session_init = get_model_loader(args.load)
SyncMultiGPUTrainerParameterServer(config).train() launch_train_with_config(config, SyncMultiGPUTrainerParameterServer(nr_gpu))
../ResNet/imagenet_resnet_utils.py
\ No newline at end of file
../ResNet/imagenet_utils.py
\ No newline at end of file
../ResNet/resnet_model.py
\ No newline at end of file
...@@ -10,10 +10,11 @@ import cv2 ...@@ -10,10 +10,11 @@ import cv2
import tensorflow as tf import tensorflow as tf
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import logger, QueueInput, InputDesc, PlaceholderInput, TowerContext from tensorpack import logger, QueueInput, InputDesc, PlaceholderInput, TowerContext
from tensorpack.models import * from tensorpack.models import *
from tensorpack.callbacks import * from tensorpack.callbacks import *
from tensorpack.train import TrainConfig, SyncMultiGPUTrainerParameterServer from tensorpack.train import *
from tensorpack.dataflow import imgaug from tensorpack.dataflow import imgaug
from tensorpack.tfutils import argscope, get_model_loader from tensorpack.tfutils import argscope, get_model_loader
from tensorpack.tfutils.scope_utils import under_name_scope from tensorpack.tfutils.scope_utils import under_name_scope
...@@ -141,8 +142,7 @@ def get_data(name, batch): ...@@ -141,8 +142,7 @@ def get_data(name, batch):
args.data, name, batch, augmentors) args.data, name, batch, augmentors)
def get_config(model): def get_config(model, nr_tower):
nr_tower = max(get_nr_gpu(), 1)
batch = TOTAL_BATCH_SIZE // nr_tower batch = TOTAL_BATCH_SIZE // nr_tower
logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, batch)) logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, batch))
...@@ -170,7 +170,6 @@ def get_config(model): ...@@ -170,7 +170,6 @@ def get_config(model):
callbacks=callbacks, callbacks=callbacks,
steps_per_epoch=5000, steps_per_epoch=5000,
max_epoch=100, max_epoch=100,
nr_tower=nr_tower
) )
...@@ -205,5 +204,6 @@ if __name__ == '__main__': ...@@ -205,5 +204,6 @@ if __name__ == '__main__':
logger.set_logger_dir( logger.set_logger_dir(
os.path.join('train_log', 'shufflenet')) os.path.join('train_log', 'shufflenet'))
config = get_config(model) nr_tower = max(get_nr_gpu(), 1)
SyncMultiGPUTrainerParameterServer(config).train() config = get_config(model, nr_tower)
launch_train_with_config(config, SyncMultiGPUTrainerParameterServer(nr_tower))
...@@ -9,7 +9,7 @@ import argparse ...@@ -9,7 +9,7 @@ import argparse
import tensorflow as tf import tensorflow as tf
import tensorflow.contrib.slim as slim import tensorflow.contrib.slim as slim
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
import tensorpack.tfutils.symbolic_functions as symbf import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
...@@ -442,4 +442,4 @@ if __name__ == '__main__': ...@@ -442,4 +442,4 @@ if __name__ == '__main__':
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
else: else:
SimpleTrainer(config).train() launch_train_with_config(config, SimpleTrainer())
...@@ -10,6 +10,7 @@ import os ...@@ -10,6 +10,7 @@ import os
import sys import sys
import argparse import argparse
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
from tensorpack.dataflow import dataset from tensorpack.dataflow import dataset
from tensorpack.tfutils import sesscreate, optimizer, summary from tensorpack.tfutils import sesscreate, optimizer, summary
...@@ -186,4 +187,4 @@ if __name__ == '__main__': ...@@ -186,4 +187,4 @@ if __name__ == '__main__':
config = get_config() config = get_config()
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
SimpleTrainer(config).train() launch_train_with_config(config, SimpleTrainer())
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment