Commit 17a73a4c authored by Yuxin Wu's avatar Yuxin Wu

Update more examples to use Trainerv2

parent e63fe50f
......@@ -9,6 +9,7 @@ import numpy as np
import os
import tensorflow as tf
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
......@@ -192,6 +193,6 @@ if __name__ == '__main__':
if args.load:
config.session_init = SaverRestore(args.load)
if args.gpu:
config.nr_tower = len(args.gpu.split(','))
assert config.nr_tower == NR_GPU
SyncMultiGPUTrainer(config).train()
nr_tower = len(args.gpu.split(','))
assert nr_tower == NR_GPU
launch_train_with_config(config, SyncMultiGPUTrainer(NR_GPU))
......@@ -10,6 +10,7 @@ import os
import tensorflow as tf
import multiprocessing
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
......@@ -298,5 +299,4 @@ if __name__ == '__main__':
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
config.nr_tower = NR_GPU
SyncMultiGPUTrainer(config).train()
launch_train_with_config(config, SyncMultiGPUTrainer(NR_GPU))
......@@ -7,6 +7,7 @@ import numpy as np
import os
import argparse
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import *
from tensorpack.tfutils.gradproc import *
from tensorpack.tfutils import optimizer, summary
......@@ -174,4 +175,4 @@ if __name__ == '__main__':
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
SimpleTrainer(config).train()
launch_train_with_config(config, SimpleTrainer())
......@@ -7,6 +7,7 @@ import numpy as np
import argparse
import os
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
......@@ -171,7 +172,7 @@ if __name__ == '__main__':
[(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0002)])
],
max_epoch=400,
nr_tower=max(get_nr_gpu(), 1),
session_init=SaverRestore(args.load) if args.load else None
)
SyncMultiGPUTrainerParameterServer(config).train()
nr_gpu = max(get_nr_gpu(), 1)
launch_train_with_config(config, SyncMultiGPUTrainerParameterServer(nr_gpu))
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: svhn-resnet.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import argparse
import numpy as np
import os
from tensorpack import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.dataflow import dataset
from tensorpack.utils.gpu import get_nr_gpu
import tensorflow as tf
"""
ResNet-110 for SVHN Digit Classification.
Reach 1.8% validation error after 70 epochs, with 2 TitanX. 2it/s.
You might need to adjust the learning rate schedule when running with 1 GPU.
"""
import imp
cifar_example = imp.load_source('cifar_example',
os.path.join(os.path.dirname(__file__), 'cifar10-resnet.py'))
Model = cifar_example.Model
BATCH_SIZE = 128
def get_data(train_or_test):
isTrain = train_or_test == 'train'
pp_mean = dataset.SVHNDigit.get_per_pixel_mean()
if isTrain:
d1 = dataset.SVHNDigit('train')
d2 = dataset.SVHNDigit('extra')
ds = RandomMixData([d1, d2])
else:
ds = dataset.SVHNDigit('test')
if isTrain:
augmentors = [
imgaug.CenterPaste((40, 40)),
imgaug.Brightness(10),
imgaug.Contrast((0.8, 1.2)),
imgaug.GaussianDeform( # this is slow. without it, can only reach 1.9% error
[(0.2, 0.2), (0.2, 0.8), (0.8, 0.8), (0.8, 0.2)],
(40, 40), 0.2, 3),
imgaug.RandomCrop((32, 32)),
imgaug.MapImage(lambda x: x - pp_mean),
]
else:
augmentors = [
imgaug.MapImage(lambda x: x - pp_mean)
]
ds = AugmentImageComponent(ds, augmentors)
ds = BatchData(ds, 128, remainder=not isTrain)
if isTrain:
ds = PrefetchData(ds, 5, 5)
return ds
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--load', help='load model')
args = parser.parse_args()
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
logger.auto_set_dir()
dataset_train = get_data('train')
dataset_test = get_data('test')
config = TrainConfig(
model=Model(n=18),
dataflow=dataset_train,
callbacks=[
ModelSaver(),
InferenceRunner(dataset_test,
[ScalarStats('cost'), ClassificationError()]),
ScheduledHyperParamSetter('learning_rate',
[(1, 0.1), (20, 0.01), (28, 0.001), (50, 0.0001)])
],
nr_tower=max(get_nr_gpu(), 1),
session_init=SaverRestore(args.load) if args.load else None,
max_epoch=500,
)
SyncMultiGPUTrainerParameterServer(config).train()
......@@ -9,6 +9,7 @@ import numpy as np
import os
import multiprocessing
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
import tensorflow as tf
from tensorflow.contrib.layers import variance_scaling_initializer
from tensorpack import *
......@@ -19,9 +20,10 @@ from tensorpack.tfutils.summary import *
from tensorpack.utils.gpu import get_nr_gpu
from tensorpack.utils import viz
from imagenet_resnet_utils import (
fbresnet_augmentor, preresnet_basicblock, preresnet_group,
image_preprocess, compute_loss_and_error)
from imagenet_utils import (
fbresnet_augmentor, image_preprocess, compute_loss_and_error)
from resnet_model import (
preresnet_basicblock, preresnet_group)
TOTAL_BATCH_SIZE = 256
......@@ -90,10 +92,6 @@ def get_data(train_or_test):
def get_config():
nr_gpu = get_nr_gpu()
global BATCH_SIZE
BATCH_SIZE = TOTAL_BATCH_SIZE // nr_gpu
dataset_train = get_data('train')
dataset_val = get_data('val')
......@@ -111,7 +109,6 @@ def get_config():
],
steps_per_epoch=5000,
max_epoch=105,
nr_tower=nr_gpu
)
......@@ -163,6 +160,9 @@ if __name__ == '__main__':
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
nr_gpu = get_nr_gpu()
BATCH_SIZE = TOTAL_BATCH_SIZE // nr_gpu
if args.cam:
BATCH_SIZE = 128 # something that can run on one gpu
viz_cam(args.load, args.data)
......@@ -172,4 +172,4 @@ if __name__ == '__main__':
config = get_config()
if args.load:
config.session_init = get_model_loader(args.load)
SyncMultiGPUTrainerParameterServer(config).train()
launch_train_with_config(config, SyncMultiGPUTrainerParameterServer(nr_gpu))
../ResNet/imagenet_resnet_utils.py
\ No newline at end of file
../ResNet/imagenet_utils.py
\ No newline at end of file
../ResNet/resnet_model.py
\ No newline at end of file
......@@ -10,10 +10,11 @@ import cv2
import tensorflow as tf
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import logger, QueueInput, InputDesc, PlaceholderInput, TowerContext
from tensorpack.models import *
from tensorpack.callbacks import *
from tensorpack.train import TrainConfig, SyncMultiGPUTrainerParameterServer
from tensorpack.train import *
from tensorpack.dataflow import imgaug
from tensorpack.tfutils import argscope, get_model_loader
from tensorpack.tfutils.scope_utils import under_name_scope
......@@ -141,8 +142,7 @@ def get_data(name, batch):
args.data, name, batch, augmentors)
def get_config(model):
nr_tower = max(get_nr_gpu(), 1)
def get_config(model, nr_tower):
batch = TOTAL_BATCH_SIZE // nr_tower
logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, batch))
......@@ -170,7 +170,6 @@ def get_config(model):
callbacks=callbacks,
steps_per_epoch=5000,
max_epoch=100,
nr_tower=nr_tower
)
......@@ -205,5 +204,6 @@ if __name__ == '__main__':
logger.set_logger_dir(
os.path.join('train_log', 'shufflenet'))
config = get_config(model)
SyncMultiGPUTrainerParameterServer(config).train()
nr_tower = max(get_nr_gpu(), 1)
config = get_config(model, nr_tower)
launch_train_with_config(config, SyncMultiGPUTrainerParameterServer(nr_tower))
......@@ -9,7 +9,7 @@ import argparse
import tensorflow as tf
import tensorflow.contrib.slim as slim
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import *
import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.tfutils.summary import add_moving_summary
......@@ -442,4 +442,4 @@ if __name__ == '__main__':
if args.load:
config.session_init = SaverRestore(args.load)
else:
SimpleTrainer(config).train()
launch_train_with_config(config, SimpleTrainer())
......@@ -10,6 +10,7 @@ import os
import sys
import argparse
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import *
from tensorpack.dataflow import dataset
from tensorpack.tfutils import sesscreate, optimizer, summary
......@@ -186,4 +187,4 @@ if __name__ == '__main__':
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
SimpleTrainer(config).train()
launch_train_with_config(config, SimpleTrainer())
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment