Commit dc709e94 authored by Yuxin Wu's avatar Yuxin Wu

update some examples

parent 0c65c338
...@@ -18,6 +18,7 @@ import tensorflow as tf ...@@ -18,6 +18,7 @@ import tensorflow as tf
import six import six
from six.moves import queue from six.moves import queue
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
from tensorpack.utils.concurrency import * from tensorpack.utils.concurrency import *
from tensorpack.utils.serialize import * from tensorpack.utils.serialize import *
...@@ -303,5 +304,5 @@ if __name__ == '__main__': ...@@ -303,5 +304,5 @@ if __name__ == '__main__':
config = get_config() config = get_config()
if args.load: if args.load:
config.session_init = get_model_loader(args.load) config.session_init = get_model_loader(args.load)
trainer = QueueInputTrainer if config.nr_tower == 1 else AsyncMultiGPUTrainer trainer = SimpleTrainer() if config.nr_tower == 1 else AsyncMultiGPUTrainer(config.tower)
trainer(config).train() launch_train_with_config(config, trainer)
...@@ -12,6 +12,7 @@ import operator ...@@ -12,6 +12,7 @@ import operator
import six import six
from six.moves import map, range from six.moves import map, range
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.gradproc import SummaryGradient, GlobalNormClip from tensorpack.tfutils.gradproc import SummaryGradient, GlobalNormClip
from tensorpack.utils.globvars import globalns as param from tensorpack.utils.globvars import globalns as param
...@@ -94,7 +95,7 @@ def get_data(path, isTrain, stat_file): ...@@ -94,7 +95,7 @@ def get_data(path, isTrain, stat_file):
def get_config(ds_train, ds_test): def get_config(ds_train, ds_test):
return TrainConfig( return TrainConfig(
dataflow=ds_train, data=QueueInput(ds_train),
callbacks=[ callbacks=[
ModelSaver(), ModelSaver(),
StatMonitorParamSetter('learning_rate', 'error', StatMonitorParamSetter('learning_rate', 'error',
...@@ -128,4 +129,4 @@ if __name__ == '__main__': ...@@ -128,4 +129,4 @@ if __name__ == '__main__':
config = get_config(ds_train, ds_test) config = get_config(ds_train, ds_test)
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
QueueInputTrainer(config).train() launch_train_with_config(config, SimpleTrainer())
...@@ -12,6 +12,7 @@ import operator ...@@ -12,6 +12,7 @@ import operator
import six import six
from six.moves import map, range from six.moves import map, range
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
from tensorpack.tfutils import symbolic_functions, summary, optimizer from tensorpack.tfutils import symbolic_functions, summary, optimizer
from tensorpack.tfutils.gradproc import GlobalNormClip from tensorpack.tfutils.gradproc import GlobalNormClip
...@@ -116,7 +117,7 @@ def get_config(): ...@@ -116,7 +117,7 @@ def get_config():
ds = BatchData(ds, param.batch_size) ds = BatchData(ds, param.batch_size)
return TrainConfig( return TrainConfig(
dataflow=ds, data=QueueInput(ds),
callbacks=[ callbacks=[
ModelSaver(), ModelSaver(),
ScheduledHyperParamSetter('learning_rate', [(25, 2e-4)]) ScheduledHyperParamSetter('learning_rate', [(25, 2e-4)])
...@@ -190,4 +191,4 @@ if __name__ == '__main__': ...@@ -190,4 +191,4 @@ if __name__ == '__main__':
config = get_config() config = get_config()
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
QueueInputTrainer(config).train() launch_train_with_config(config, SimpleTrainer())
...@@ -16,6 +16,7 @@ import multiprocessing ...@@ -16,6 +16,7 @@ import multiprocessing
import threading import threading
from collections import deque from collections import deque
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
from tensorpack.utils.concurrency import * from tensorpack.utils.concurrency import *
import tensorflow as tf import tensorflow as tf
...@@ -105,7 +106,7 @@ def get_config(): ...@@ -105,7 +106,7 @@ def get_config():
) )
return TrainConfig( return TrainConfig(
dataflow=expreplay, data=QueueInput(expreplay),
model=Model(), model=Model(),
callbacks=[ callbacks=[
ModelSaver(), ModelSaver(),
...@@ -166,4 +167,4 @@ if __name__ == '__main__': ...@@ -166,4 +167,4 @@ if __name__ == '__main__':
config = get_config() config = get_config()
if args.load: if args.load:
config.session_init = get_model_loader(args.load) config.session_init = get_model_loader(args.load)
QueueInputTrainer(config).train() launch_train_with_config(config, SimpleTrainer())
...@@ -8,6 +8,7 @@ import os ...@@ -8,6 +8,7 @@ import os
import sys import sys
import argparse import argparse
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
from tensorpack.dataflow import dataset from tensorpack.dataflow import dataset
import tensorflow as tf import tensorflow as tf
...@@ -65,4 +66,4 @@ if __name__ == '__main__': ...@@ -65,4 +66,4 @@ if __name__ == '__main__':
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
config = get_config() config = get_config()
QueueInputTrainer(config).train() launch_train_with_config(config, SimpleTrainer())
...@@ -8,6 +8,7 @@ import numpy as np ...@@ -8,6 +8,7 @@ import numpy as np
import os import os
import imp import imp
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
from tensorpack.dataflow import dataset from tensorpack.dataflow import dataset
...@@ -56,4 +57,4 @@ if __name__ == '__main__': ...@@ -56,4 +57,4 @@ if __name__ == '__main__':
os.environ['CUDA_VISIBLE_DEVICES'] = '0' os.environ['CUDA_VISIBLE_DEVICES'] = '0'
config = get_config(args.prob) config = get_config(args.prob)
QueueInputTrainer(config).train() launch_train_with_config(config, SimpleTrainer())
...@@ -11,6 +11,7 @@ import multiprocessing ...@@ -11,6 +11,7 @@ import multiprocessing
import os import os
import sys import sys
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.symbolic_functions import * from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import * from tensorpack.tfutils.summary import *
...@@ -322,4 +323,4 @@ if __name__ == '__main__': ...@@ -322,4 +323,4 @@ if __name__ == '__main__':
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
config.nr_tower = nr_tower config.nr_tower = nr_tower
SyncMultiGPUTrainer(config).train() launch_train_with_config(configi, SyncMultiGPUTrainer(list(range(nr_tower))))
...@@ -7,6 +7,7 @@ import argparse ...@@ -7,6 +7,7 @@ import argparse
import numpy as np import numpy as np
import os import os
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.symbolic_functions import * from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import * from tensorpack.tfutils.summary import *
...@@ -163,7 +164,7 @@ def get_config(): ...@@ -163,7 +164,7 @@ def get_config():
data_test = BatchData(data_test, 128, remainder=True) data_test = BatchData(data_test, 128, remainder=True)
return TrainConfig( return TrainConfig(
dataflow=data_train, data=QueueInput(data_train),
callbacks=[ callbacks=[
ModelSaver(), ModelSaver(),
InferenceRunner(data_test, InferenceRunner(data_test,
...@@ -183,4 +184,4 @@ if __name__ == '__main__': ...@@ -183,4 +184,4 @@ if __name__ == '__main__':
BITW, BITA, BITG = map(int, args.dorefa.split(',')) BITW, BITA, BITG = map(int, args.dorefa.split(','))
config = get_config() config = get_config()
QueueInputTrainer(config).train() launch_train_with_config(config, SimpleTrainer())
...@@ -6,10 +6,12 @@ import argparse ...@@ -6,10 +6,12 @@ import argparse
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
import cv2 import cv2
import os
from scipy.signal import convolve2d from scipy.signal import convolve2d
from six.moves import range, zip from six.moves import range, zip
import multiprocessing import multiprocessing
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
from tensorpack.utils import logger from tensorpack.utils import logger
from tensorpack.utils.viz import * from tensorpack.utils.viz import *
...@@ -262,5 +264,4 @@ if __name__ == '__main__': ...@@ -262,5 +264,4 @@ if __name__ == '__main__':
config = get_config() config = get_config()
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
config.nr_tower = NR_GPU launch_train_with_config(config, SyncMultiGPUTrainer(list(range(NR_GPU))))
SyncMultiGPUTrainer(config).train()
...@@ -13,6 +13,7 @@ import numpy as np ...@@ -13,6 +13,7 @@ import numpy as np
import json import json
import tensorflow as tf import tensorflow as tf
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
import tensorpack.tfutils.symbolic_functions as symbf import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
...@@ -300,6 +301,6 @@ if __name__ == '__main__': ...@@ -300,6 +301,6 @@ if __name__ == '__main__':
steps_per_epoch=stepnum, steps_per_epoch=stepnum,
max_epoch=205000 * factor // stepnum, max_epoch=205000 * factor // stepnum,
session_init=get_model_loader(args.load) if args.load else None, session_init=get_model_loader(args.load) if args.load else None,
nr_tower=get_nr_gpu()
) )
SyncMultiGPUTrainerReplicated(cfg, gpu_prefetch=False).train() trainer = SyncMultiGPUTrainerReplicated(range(len(get_nr_gpu())))
launch_train_with_config(cfg, trainer)
...@@ -11,6 +11,7 @@ from six.moves import zip ...@@ -11,6 +11,7 @@ from six.moves import zip
import os import os
import sys import sys
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
import tensorpack.tfutils.symbolic_functions as symbf import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.dataflow import dataset from tensorpack.dataflow import dataset
...@@ -231,5 +232,6 @@ if __name__ == '__main__': ...@@ -231,5 +232,6 @@ if __name__ == '__main__':
config = get_config() config = get_config()
if args.load: if args.load:
config.session_init = get_model_loader(args.load) config.session_init = get_model_loader(args.load)
config.nr_tower = max(get_nr_gpu(), 1) launch_train_with_config(
SyncMultiGPUTrainer(config).train() config,
SyncMultiGPUTrainer(range(max(get_nr_gpu(), 1))))
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
import os import os
import argparse import argparse
import tensorflow as tf import tensorflow as tf
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
""" """
...@@ -51,7 +52,7 @@ def get_config(): ...@@ -51,7 +52,7 @@ def get_config():
return TrainConfig( return TrainConfig(
model=Model(), model=Model(),
dataflow=ds_train, data=QueueInput(ds_train),
callbacks=[ callbacks=[
ModelSaver(), ModelSaver(),
InferenceRunner(ds_test, [ScalarStats('total_costs')]), InferenceRunner(ds_test, [ScalarStats('total_costs')]),
...@@ -77,4 +78,4 @@ if __name__ == '__main__': ...@@ -77,4 +78,4 @@ if __name__ == '__main__':
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
SyncMultiGPUTrainer(config).train() launch_train_with_config(config, SimpleTrainer())
...@@ -14,6 +14,7 @@ the only differences are: ...@@ -14,6 +14,7 @@ the only differences are:
2. use slim names to summarize weights 2. use slim names to summarize weights
""" """
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
from tensorpack.dataflow import dataset from tensorpack.dataflow import dataset
import tensorflow as tf import tensorflow as tf
...@@ -101,4 +102,4 @@ if __name__ == '__main__': ...@@ -101,4 +102,4 @@ if __name__ == '__main__':
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
config = get_config() config = get_config()
SimpleTrainer(config).train() launch_train_with_config(config, SimpleTrainer())
...@@ -11,6 +11,7 @@ import argparse ...@@ -11,6 +11,7 @@ import argparse
MNIST ConvNet example with weights/activations visualization. MNIST ConvNet example with weights/activations visualization.
""" """
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
from tensorpack.dataflow import dataset from tensorpack.dataflow import dataset
import tensorflow as tf import tensorflow as tf
...@@ -161,4 +162,4 @@ if __name__ == '__main__': ...@@ -161,4 +162,4 @@ if __name__ == '__main__':
config = get_config() config = get_config()
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
SimpleTrainer(config).train() launch_train_with_config(config, SimpleTrainer())
...@@ -7,6 +7,7 @@ import argparse ...@@ -7,6 +7,7 @@ import argparse
import numpy as np import numpy as np
import os import os
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.symbolic_functions import prediction_incorrect from tensorpack.tfutils.symbolic_functions import prediction_incorrect
from tensorpack.dataflow import dataset from tensorpack.dataflow import dataset
...@@ -99,7 +100,7 @@ def get_config(): ...@@ -99,7 +100,7 @@ def get_config():
return TrainConfig( return TrainConfig(
model=Model(), model=Model(),
dataflow=data_train, data=QueueInput(data_train),
callbacks=[ callbacks=[
ModelSaver(), ModelSaver(),
InferenceRunner(data_test, InferenceRunner(data_test,
...@@ -125,4 +126,4 @@ if __name__ == '__main__': ...@@ -125,4 +126,4 @@ if __name__ == '__main__':
config = get_config() config = get_config()
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
QueueInputTrainer(config).train() launch_train_with_config(config, SimpleTrainer())
...@@ -43,12 +43,12 @@ def apply_default_prefetch(input_source_or_dataflow, trainer, towers): ...@@ -43,12 +43,12 @@ def apply_default_prefetch(input_source_or_dataflow, trainer, towers):
def launch_train_with_config(config, trainer): def launch_train_with_config(config, trainer):
""" """
Train with a :class:`TrainConfig` and a new version of :class:`Trainer`, to Train with a :class:`TrainConfig` and a :class:`Trainer`, to
mimic the old training interface. mimic the old training interface.
Args: Args:
config (TrainConfig): config (TrainConfig):
trainer (Trainer): an instance of the new trainer trainer (Trainer): an instance of a SingleCostTrainer
Examples: Examples:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment