Commit dc709e94 authored by Yuxin Wu's avatar Yuxin Wu

update some examples

parent 0c65c338
......@@ -18,6 +18,7 @@ import tensorflow as tf
import six
from six.moves import queue
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import *
from tensorpack.utils.concurrency import *
from tensorpack.utils.serialize import *
......@@ -303,5 +304,5 @@ if __name__ == '__main__':
config = get_config()
if args.load:
config.session_init = get_model_loader(args.load)
trainer = QueueInputTrainer if config.nr_tower == 1 else AsyncMultiGPUTrainer
trainer(config).train()
trainer = SimpleTrainer() if config.nr_tower == 1 else AsyncMultiGPUTrainer(config.tower)
launch_train_with_config(config, trainer)
......@@ -12,6 +12,7 @@ import operator
import six
from six.moves import map, range
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import *
from tensorpack.tfutils.gradproc import SummaryGradient, GlobalNormClip
from tensorpack.utils.globvars import globalns as param
......@@ -94,7 +95,7 @@ def get_data(path, isTrain, stat_file):
def get_config(ds_train, ds_test):
return TrainConfig(
dataflow=ds_train,
data=QueueInput(ds_train),
callbacks=[
ModelSaver(),
StatMonitorParamSetter('learning_rate', 'error',
......@@ -128,4 +129,4 @@ if __name__ == '__main__':
config = get_config(ds_train, ds_test)
if args.load:
config.session_init = SaverRestore(args.load)
QueueInputTrainer(config).train()
launch_train_with_config(config, SimpleTrainer())
......@@ -12,6 +12,7 @@ import operator
import six
from six.moves import map, range
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import *
from tensorpack.tfutils import symbolic_functions, summary, optimizer
from tensorpack.tfutils.gradproc import GlobalNormClip
......@@ -116,7 +117,7 @@ def get_config():
ds = BatchData(ds, param.batch_size)
return TrainConfig(
dataflow=ds,
data=QueueInput(ds),
callbacks=[
ModelSaver(),
ScheduledHyperParamSetter('learning_rate', [(25, 2e-4)])
......@@ -190,4 +191,4 @@ if __name__ == '__main__':
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
QueueInputTrainer(config).train()
launch_train_with_config(config, SimpleTrainer())
......@@ -16,6 +16,7 @@ import multiprocessing
import threading
from collections import deque
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import *
from tensorpack.utils.concurrency import *
import tensorflow as tf
......@@ -105,7 +106,7 @@ def get_config():
)
return TrainConfig(
dataflow=expreplay,
data=QueueInput(expreplay),
model=Model(),
callbacks=[
ModelSaver(),
......@@ -166,4 +167,4 @@ if __name__ == '__main__':
config = get_config()
if args.load:
config.session_init = get_model_loader(args.load)
QueueInputTrainer(config).train()
launch_train_with_config(config, SimpleTrainer())
......@@ -8,6 +8,7 @@ import os
import sys
import argparse
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import *
from tensorpack.dataflow import dataset
import tensorflow as tf
......@@ -65,4 +66,4 @@ if __name__ == '__main__':
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
config = get_config()
QueueInputTrainer(config).train()
launch_train_with_config(config, SimpleTrainer())
......@@ -8,6 +8,7 @@ import numpy as np
import os
import imp
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import *
from tensorpack.dataflow import dataset
......@@ -56,4 +57,4 @@ if __name__ == '__main__':
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
config = get_config(args.prob)
QueueInputTrainer(config).train()
launch_train_with_config(config, SimpleTrainer())
......@@ -11,6 +11,7 @@ import multiprocessing
import os
import sys
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
......@@ -322,4 +323,4 @@ if __name__ == '__main__':
if args.load:
config.session_init = SaverRestore(args.load)
config.nr_tower = nr_tower
SyncMultiGPUTrainer(config).train()
launch_train_with_config(configi, SyncMultiGPUTrainer(list(range(nr_tower))))
......@@ -7,6 +7,7 @@ import argparse
import numpy as np
import os
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
......@@ -163,7 +164,7 @@ def get_config():
data_test = BatchData(data_test, 128, remainder=True)
return TrainConfig(
dataflow=data_train,
data=QueueInput(data_train),
callbacks=[
ModelSaver(),
InferenceRunner(data_test,
......@@ -183,4 +184,4 @@ if __name__ == '__main__':
BITW, BITA, BITG = map(int, args.dorefa.split(','))
config = get_config()
QueueInputTrainer(config).train()
launch_train_with_config(config, SimpleTrainer())
......@@ -6,10 +6,12 @@ import argparse
import numpy as np
import tensorflow as tf
import cv2
import os
from scipy.signal import convolve2d
from six.moves import range, zip
import multiprocessing
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import *
from tensorpack.utils import logger
from tensorpack.utils.viz import *
......@@ -262,5 +264,4 @@ if __name__ == '__main__':
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
config.nr_tower = NR_GPU
SyncMultiGPUTrainer(config).train()
launch_train_with_config(config, SyncMultiGPUTrainer(list(range(NR_GPU))))
......@@ -13,6 +13,7 @@ import numpy as np
import json
import tensorflow as tf
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import *
import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.tfutils.summary import add_moving_summary
......@@ -300,6 +301,6 @@ if __name__ == '__main__':
steps_per_epoch=stepnum,
max_epoch=205000 * factor // stepnum,
session_init=get_model_loader(args.load) if args.load else None,
nr_tower=get_nr_gpu()
)
SyncMultiGPUTrainerReplicated(cfg, gpu_prefetch=False).train()
trainer = SyncMultiGPUTrainerReplicated(range(len(get_nr_gpu())))
launch_train_with_config(cfg, trainer)
......@@ -11,6 +11,7 @@ from six.moves import zip
import os
import sys
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import *
import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.dataflow import dataset
......@@ -231,5 +232,6 @@ if __name__ == '__main__':
config = get_config()
if args.load:
config.session_init = get_model_loader(args.load)
config.nr_tower = max(get_nr_gpu(), 1)
SyncMultiGPUTrainer(config).train()
launch_train_with_config(
config,
SyncMultiGPUTrainer(range(max(get_nr_gpu(), 1))))
......@@ -5,6 +5,7 @@
import os
import argparse
import tensorflow as tf
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import *
"""
......@@ -51,7 +52,7 @@ def get_config():
return TrainConfig(
model=Model(),
dataflow=ds_train,
data=QueueInput(ds_train),
callbacks=[
ModelSaver(),
InferenceRunner(ds_test, [ScalarStats('total_costs')]),
......@@ -77,4 +78,4 @@ if __name__ == '__main__':
if args.load:
config.session_init = SaverRestore(args.load)
SyncMultiGPUTrainer(config).train()
launch_train_with_config(config, SimpleTrainer())
......@@ -14,6 +14,7 @@ the only differences are:
2. use slim names to summarize weights
"""
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import *
from tensorpack.dataflow import dataset
import tensorflow as tf
......@@ -101,4 +102,4 @@ if __name__ == '__main__':
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
config = get_config()
SimpleTrainer(config).train()
launch_train_with_config(config, SimpleTrainer())
......@@ -11,6 +11,7 @@ import argparse
MNIST ConvNet example with weights/activations visualization.
"""
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import *
from tensorpack.dataflow import dataset
import tensorflow as tf
......@@ -161,4 +162,4 @@ if __name__ == '__main__':
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
SimpleTrainer(config).train()
launch_train_with_config(config, SimpleTrainer())
......@@ -7,6 +7,7 @@ import argparse
import numpy as np
import os
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import *
from tensorpack.tfutils.symbolic_functions import prediction_incorrect
from tensorpack.dataflow import dataset
......@@ -99,7 +100,7 @@ def get_config():
return TrainConfig(
model=Model(),
dataflow=data_train,
data=QueueInput(data_train),
callbacks=[
ModelSaver(),
InferenceRunner(data_test,
......@@ -125,4 +126,4 @@ if __name__ == '__main__':
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
QueueInputTrainer(config).train()
launch_train_with_config(config, SimpleTrainer())
......@@ -43,12 +43,12 @@ def apply_default_prefetch(input_source_or_dataflow, trainer, towers):
def launch_train_with_config(config, trainer):
"""
Train with a :class:`TrainConfig` and a new version of :class:`Trainer`, to
Train with a :class:`TrainConfig` and a :class:`Trainer`, to
mimic the old training interface.
Args:
config (TrainConfig):
trainer (Trainer): an instance of the new trainer
trainer (Trainer): an instance of a SingleCostTrainer
Examples:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment