Commit 17b34c69 authored by Yuxin Wu

Use SmartInit globally - a simpler interface to initialization

parent cbd698ad
...@@ -39,13 +39,13 @@ For inference, use `session_init` in `PredictConfig(...)`. ...@@ -39,13 +39,13 @@ For inference, use `session_init` in `PredictConfig(...)`.
There are a few ways a session can be initialized: There are a few ways a session can be initialized:
``` ```
session_init=SmartRestore("path/to/checkpoint") # load a TF checkpoint session_init=SmartInit("path/to/checkpoint") # load a TF checkpoint
session_init=SmartRestore("path/to/model_zoo.npz") # load tensorpack model zoo session_init=SmartInit("path/to/model_zoo.npz") # load tensorpack model zoo
session_init=SmartRestore(dict_of_parameters) # load a dictionary session_init=SmartInit(dict_of_parameters) # load a dictionary
session_init=SmartRestore(["path1", dict2]) # load them sequentially session_init=SmartInit(["path1", dict2]) # load them sequentially
``` ```
[SmartRestore](../modules/tfutils.html#tensorpack.tfutils.sessinit.SmartRestore) [SmartInit](../modules/tfutils.html#tensorpack.tfutils.sessinit.SmartInit)
is in fact a small helper which uses some heuristics to return you one of is in fact a small helper which uses some heuristics to return you one of
[SaverRestore](../modules/tfutils.html#tensorpack.tfutils.sessinit.SaverRestore) or [SaverRestore](../modules/tfutils.html#tensorpack.tfutils.sessinit.SaverRestore) or
[DictRestore](../modules/tfutils.html#tensorpack.tfutils.sessinit.DictRestore). [DictRestore](../modules/tfutils.html#tensorpack.tfutils.sessinit.DictRestore).
......
...@@ -265,7 +265,7 @@ def train(): ...@@ -265,7 +265,7 @@ def train():
], ],
session_creator=sesscreate.NewSessionCreator(config=get_default_sess_config(0.5)), session_creator=sesscreate.NewSessionCreator(config=get_default_sess_config(0.5)),
steps_per_epoch=STEPS_PER_EPOCH, steps_per_epoch=STEPS_PER_EPOCH,
session_init=get_model_loader(args.load) if args.load else None, session_init=SmartInit(args.load),
max_epoch=1000, max_epoch=1000,
) )
trainer = SimpleTrainer() if num_gpu == 1 else AsyncMultiGPUTrainer(train_tower) trainer = SimpleTrainer() if num_gpu == 1 else AsyncMultiGPUTrainer(train_tower)
...@@ -294,7 +294,7 @@ if __name__ == '__main__': ...@@ -294,7 +294,7 @@ if __name__ == '__main__':
assert args.load is not None assert args.load is not None
pred = OfflinePredictor(PredictConfig( pred = OfflinePredictor(PredictConfig(
model=Model(), model=Model(),
session_init=get_model_loader(args.load), session_init=SmartInit(args.load),
input_names=['state'], input_names=['state'],
output_names=['policy'])) output_names=['policy']))
if args.task == 'play': if args.task == 'play':
......
...@@ -119,6 +119,5 @@ if __name__ == '__main__': ...@@ -119,6 +119,5 @@ if __name__ == '__main__':
ds_test = get_data(args.test, False, args.stat) ds_test = get_data(args.test, False, args.stat)
config = get_config(ds_train, ds_test) config = get_config(ds_train, ds_test)
if args.load: config.session_init = SmartInit(args.load)
config.session_init = SaverRestore(args.load)
launch_train_with_config(config, SimpleTrainer()) launch_train_with_config(config, SimpleTrainer())
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
from __future__ import print_function from __future__ import print_function
import argparse import argparse
import numpy as np
import os import os
import cv2 import cv2
import tensorflow as tf import tensorflow as tf
...@@ -39,11 +38,10 @@ def tower_func(image): ...@@ -39,11 +38,10 @@ def tower_func(image):
def run_test(path, input): def run_test(path, input):
param_dict = dict(np.load(path))
predictor = OfflinePredictor(PredictConfig( predictor = OfflinePredictor(PredictConfig(
input_signature=[tf.TensorSpec((None, 227, 227, 3), tf.float32, 'input')], input_signature=[tf.TensorSpec((None, 227, 227, 3), tf.float32, 'input')],
tower_func=tower_func, tower_func=tower_func,
session_init=DictRestore(param_dict), session_init=SmartInit(path),
input_names=['input'], input_names=['input'],
output_names=['prob'] output_names=['prob']
)) ))
......
...@@ -95,11 +95,10 @@ def CPM(image): ...@@ -95,11 +95,10 @@ def CPM(image):
def run_test(model_path, img_file): def run_test(model_path, img_file):
param_dict = dict(np.load(model_path))
predict_func = OfflinePredictor(PredictConfig( predict_func = OfflinePredictor(PredictConfig(
input_signature=[tf.TensorSpec((None, 368, 368, 3), tf.float32, 'input')], input_signature=[tf.TensorSpec((None, 368, 368, 3), tf.float32, 'input')],
tower_func=CPM, tower_func=CPM,
session_init=DictRestore(param_dict), session_init=SmartInit(model_path),
input_names=['input'], input_names=['input'],
output_names=['resized_map'] output_names=['resized_map']
)) ))
......
...@@ -61,7 +61,7 @@ def run_test(path, input): ...@@ -61,7 +61,7 @@ def run_test(path, input):
predict_func = OfflinePredictor(PredictConfig( predict_func = OfflinePredictor(PredictConfig(
input_signature=[tf.TensorSpec((None, 224, 224, 3), tf.float32, 'input')], input_signature=[tf.TensorSpec((None, 224, 224, 3), tf.float32, 'input')],
tower_func=tower_func, tower_func=tower_func,
session_init=DictRestore(param_dict), session_init=SmartInit(param_dict),
input_names=['input'], input_names=['input'],
output_names=['prob'] # prob:0 is the probability distribution output_names=['prob'] # prob:0 is the probability distribution
)) ))
......
...@@ -64,7 +64,7 @@ def run_test(path, input): ...@@ -64,7 +64,7 @@ def run_test(path, input):
predict_func = OfflinePredictor(PredictConfig( predict_func = OfflinePredictor(PredictConfig(
input_signature=[tf.TensorSpec((None, 224, 224, 3), tf.float32, 'input')], input_signature=[tf.TensorSpec((None, 224, 224, 3), tf.float32, 'input')],
tower_func=tower_func, tower_func=tower_func,
session_init=DictRestore(param_dict), session_init=SmartInit(param_dict),
input_names=['input'], input_names=['input'],
output_names=['prob'] # prob:0 is the probability distribution output_names=['prob'] # prob:0 is the probability distribution
)) ))
......
...@@ -141,7 +141,7 @@ def sample(path, start, length): ...@@ -141,7 +141,7 @@ def sample(path, start, length):
pred = OfflinePredictor(PredictConfig( pred = OfflinePredictor(PredictConfig(
model=Model(), model=Model(),
session_init=SaverRestore(path), session_init=SmartInit(path),
input_names=['input', 'c0', 'h0', 'c1', 'h1'], input_names=['input', 'c0', 'h0', 'c1', 'h1'],
output_names=['prob', 'last_state'])) output_names=['prob', 'last_state']))
...@@ -193,6 +193,5 @@ if __name__ == '__main__': ...@@ -193,6 +193,5 @@ if __name__ == '__main__':
else: else:
param.corpus = args.corpus param.corpus = args.corpus
config = get_config() config = get_config()
if args.load: config.session_init = SmartInit(args.load)
config.session_init = SaverRestore(args.load)
launch_train_with_config(config, SimpleTrainer()) launch_train_with_config(config, SimpleTrainer())
...@@ -171,7 +171,7 @@ if __name__ == '__main__': ...@@ -171,7 +171,7 @@ if __name__ == '__main__':
assert args.load is not None assert args.load is not None
pred = OfflinePredictor(PredictConfig( pred = OfflinePredictor(PredictConfig(
model=model, model=model,
session_init=get_model_loader(args.load), session_init=SmartInit(args.load),
input_names=['state'], input_names=['state'],
output_names=['Qvalue'])) output_names=['Qvalue']))
if args.task == 'play': if args.task == 'play':
...@@ -183,6 +183,5 @@ if __name__ == '__main__': ...@@ -183,6 +183,5 @@ if __name__ == '__main__':
os.path.join('train_log', 'DQN-{}'.format( os.path.join('train_log', 'DQN-{}'.format(
os.path.basename(args.env).split('.')[0]))) os.path.basename(args.env).split('.')[0])))
config = get_config(model) config = get_config(model)
if args.load: config.session_init = SmartInit(args.load)
config.session_init = get_model_loader(args.load)
launch_train_with_config(config, SimpleTrainer()) launch_train_with_config(config, SimpleTrainer())
...@@ -12,7 +12,7 @@ import tensorflow as tf ...@@ -12,7 +12,7 @@ import tensorflow as tf
from tensorpack import * from tensorpack import *
from tensorpack.dataflow import dataset from tensorpack.dataflow import dataset
from tensorpack.tfutils.sessinit import get_model_loader from tensorpack.tfutils.sessinit import SmartInit
from tensorpack.tfutils.summary import add_param_summary from tensorpack.tfutils.summary import add_param_summary
from tensorpack.tfutils.varreplace import remap_variables from tensorpack.tfutils.varreplace import remap_variables
from tensorpack.utils.gpu import get_num_gpu from tensorpack.utils.gpu import get_num_gpu
...@@ -214,12 +214,12 @@ if __name__ == '__main__': ...@@ -214,12 +214,12 @@ if __name__ == '__main__':
if args.run: if args.run:
assert args.load.endswith('.npz') assert args.load.endswith('.npz')
run_image(Model(), DictRestore(dict(np.load(args.load))), args.run) run_image(Model(), SmartInit(args.load), args.run)
sys.exit() sys.exit()
if args.eval: if args.eval:
BATCH_SIZE = 128 BATCH_SIZE = 128
ds = get_data('val') ds = get_data('val')
eval_classification(Model(), get_model_loader(args.load), ds) eval_classification(Model(), SmartInit(args.load), ds)
sys.exit() sys.exit()
nr_tower = max(get_num_gpu(), 1) nr_tower = max(get_num_gpu(), 1)
...@@ -229,6 +229,5 @@ if __name__ == '__main__': ...@@ -229,6 +229,5 @@ if __name__ == '__main__':
logger.info("Batch per tower: {}".format(BATCH_SIZE)) logger.info("Batch per tower: {}".format(BATCH_SIZE))
config = get_config() config = get_config()
if args.load: config.session_init = SmartInit(args.load)
config.session_init = SaverRestore(args.load)
launch_train_with_config(config, SyncMultiGPUTrainerReplicated(nr_tower)) launch_train_with_config(config, SyncMultiGPUTrainerReplicated(nr_tower))
...@@ -163,7 +163,7 @@ if __name__ == '__main__': ...@@ -163,7 +163,7 @@ if __name__ == '__main__':
ds = dataset.ILSVRC12(args.data, 'val', shuffle=False) ds = dataset.ILSVRC12(args.data, 'val', shuffle=False)
ds = AugmentImageComponent(ds, get_inference_augmentor()) ds = AugmentImageComponent(ds, get_inference_augmentor())
ds = BatchData(ds, 192, remainder=True) ds = BatchData(ds, 192, remainder=True)
eval_classification(Model(), get_model_loader(args.load), ds) eval_classification(Model(), SmartInit(args.load), ds)
elif args.run: elif args.run:
assert args.load.endswith('.npz') assert args.load.endswith('.npz')
run_image(Model(), DictRestore(dict(np.load(args.load))), args.run) run_image(Model(), SmartInit(args.load), args.run)
...@@ -255,6 +255,5 @@ if __name__ == '__main__': ...@@ -255,6 +255,5 @@ if __name__ == '__main__':
with change_gpu(args.gpu): with change_gpu(args.gpu):
NGPU = len(args.gpu.split(',')) NGPU = len(args.gpu.split(','))
config = get_config() config = get_config()
if args.load: config.session_init = SmartInit(args.load)
config.session_init = SaverRestore(args.load)
launch_train_with_config(config, SyncMultiGPUTrainer(NGPU)) launch_train_with_config(config, SyncMultiGPUTrainer(NGPU))
...@@ -14,7 +14,7 @@ assert six.PY3, "This example requires Python 3!" ...@@ -14,7 +14,7 @@ assert six.PY3, "This example requires Python 3!"
import tensorpack.utils.viz as tpviz import tensorpack.utils.viz as tpviz
from tensorpack.predict import MultiTowerOfflinePredictor, OfflinePredictor, PredictConfig from tensorpack.predict import MultiTowerOfflinePredictor, OfflinePredictor, PredictConfig
from tensorpack.tfutils import get_model_loader, get_tf_version_tuple from tensorpack.tfutils import SmartInit, get_tf_version_tuple
from tensorpack.tfutils.export import ModelExporter from tensorpack.tfutils.export import ModelExporter
from tensorpack.utils import fs, logger from tensorpack.utils import fs, logger
...@@ -38,7 +38,7 @@ def do_visualize(model, model_path, nr_visualize=100, output_dir='output'): ...@@ -38,7 +38,7 @@ def do_visualize(model, model_path, nr_visualize=100, output_dir='output'):
pred = OfflinePredictor(PredictConfig( pred = OfflinePredictor(PredictConfig(
model=model, model=model,
session_init=get_model_loader(model_path), session_init=SmartInit(model_path),
input_names=['image', 'gt_boxes', 'gt_labels'], input_names=['image', 'gt_boxes', 'gt_labels'],
output_names=[ output_names=[
'generate_{}_proposals/boxes'.format('fpn' if cfg.MODE_FPN else 'rpn'), 'generate_{}_proposals/boxes'.format('fpn' if cfg.MODE_FPN else 'rpn'),
...@@ -146,7 +146,7 @@ if __name__ == '__main__': ...@@ -146,7 +146,7 @@ if __name__ == '__main__':
else: else:
predcfg = PredictConfig( predcfg = PredictConfig(
model=MODEL, model=MODEL,
session_init=get_model_loader(args.load), session_init=SmartInit(args.load),
input_names=MODEL.get_inference_tensor_names()[0], input_names=MODEL.get_inference_tensor_names()[0],
output_names=MODEL.get_inference_tensor_names()[1]) output_names=MODEL.get_inference_tensor_names()[1])
......
...@@ -103,9 +103,9 @@ if __name__ == '__main__': ...@@ -103,9 +103,9 @@ if __name__ == '__main__':
else: else:
if args.load: if args.load:
# ignore mismatched values, so you can `--load` a model for fine-tuning # ignore mismatched values, so you can `--load` a model for fine-tuning
session_init = SmartRestore(args.load, ignore_mismatch=True) session_init = SmartInit(args.load, ignore_mismatch=True)
else: else:
session_init = SmartRestore(cfg.BACKBONE.WEIGHTS) session_init = SmartInit(cfg.BACKBONE.WEIGHTS)
traincfg = TrainConfig( traincfg = TrainConfig(
model=MODEL, model=MODEL,
......
...@@ -146,5 +146,5 @@ if __name__ == '__main__': ...@@ -146,5 +146,5 @@ if __name__ == '__main__':
StatMonitorParamSetter( StatMonitorParamSetter(
'learning_rate', 'losses/measure', lambda x: x * 0.5, 0, 10) 'learning_rate', 'losses/measure', lambda x: x * 0.5, 0, 10)
], ],
session_init=SaverRestore(args.load) if args.load else None, session_init=SmartInit(args.load),
steps_per_epoch=500, max_epoch=400) steps_per_epoch=500, max_epoch=400)
...@@ -114,7 +114,7 @@ def get_data(): ...@@ -114,7 +114,7 @@ def get_data():
def sample(model_path): def sample(model_path):
pred = PredictConfig( pred = PredictConfig(
session_init=get_model_loader(model_path), session_init=SmartInit(model_path),
model=Model(), model=Model(),
input_names=['label', 'z'], input_names=['label', 'z'],
output_names=['gen/gen']) output_names=['gen/gen'])
...@@ -145,5 +145,5 @@ if __name__ == '__main__': ...@@ -145,5 +145,5 @@ if __name__ == '__main__':
callbacks=[ModelSaver()], callbacks=[ModelSaver()],
steps_per_epoch=500, steps_per_epoch=500,
max_epoch=100, max_epoch=100,
session_init=SaverRestore(args.load) if args.load else None session_init=SmartInit(args.load),
) )
...@@ -224,5 +224,5 @@ if __name__ == '__main__': ...@@ -224,5 +224,5 @@ if __name__ == '__main__':
], ],
max_epoch=195, max_epoch=195,
steps_per_epoch=data.size(), steps_per_epoch=data.size(),
session_init=SaverRestore(args.load) if args.load else None session_init=SmartInit(args.load)
) )
...@@ -121,7 +121,7 @@ def get_data(): ...@@ -121,7 +121,7 @@ def get_data():
def sample(model, model_path, output_name='gen/gen'): def sample(model, model_path, output_name='gen/gen'):
pred = PredictConfig( pred = PredictConfig(
session_init=get_model_loader(model_path), session_init=SmartInit(model_path),
model=model, model=model,
input_names=['z'], input_names=['z'],
output_names=[output_name, 'z']) output_names=[output_name, 'z'])
...@@ -167,5 +167,5 @@ if __name__ == '__main__': ...@@ -167,5 +167,5 @@ if __name__ == '__main__':
callbacks=[ModelSaver()], callbacks=[ModelSaver()],
steps_per_epoch=300, steps_per_epoch=300,
max_epoch=200, max_epoch=200,
session_init=SaverRestore(args.load) if args.load else None session_init=SmartInit(args.load),
) )
...@@ -211,5 +211,5 @@ if __name__ == '__main__': ...@@ -211,5 +211,5 @@ if __name__ == '__main__':
callbacks=[ModelSaver()], callbacks=[ModelSaver()],
steps_per_epoch=300, steps_per_epoch=300,
max_epoch=250, max_epoch=250,
session_init=SaverRestore(args.load) if args.load else None session_init=SmartInit(args.load),
) )
...@@ -179,7 +179,7 @@ def get_data(): ...@@ -179,7 +179,7 @@ def get_data():
def sample(datadir, model_path): def sample(datadir, model_path):
pred = PredictConfig( pred = PredictConfig(
session_init=get_model_loader(model_path), session_init=SmartInit(model_path),
model=Model(), model=Model(),
input_names=['input', 'output'], input_names=['input', 'output'],
output_names=['viz']) output_names=['viz'])
...@@ -226,5 +226,5 @@ if __name__ == '__main__': ...@@ -226,5 +226,5 @@ if __name__ == '__main__':
], ],
steps_per_epoch=data.size(), steps_per_epoch=data.size(),
max_epoch=300, max_epoch=300,
session_init=SaverRestore(args.load) if args.load else None session_init=SmartInit(args.load)
) )
...@@ -97,5 +97,5 @@ if __name__ == '__main__': ...@@ -97,5 +97,5 @@ if __name__ == '__main__':
callbacks=[ModelSaver()], callbacks=[ModelSaver()],
steps_per_epoch=300, steps_per_epoch=300,
max_epoch=200, max_epoch=200,
session_init=SaverRestore(args.load) if args.load else None session_init=SmartInit(args.load)
) )
...@@ -218,7 +218,7 @@ def get_data(): ...@@ -218,7 +218,7 @@ def get_data():
def sample(model_path): def sample(model_path):
pred = OfflinePredictor(PredictConfig( pred = OfflinePredictor(PredictConfig(
session_init=get_model_loader(model_path), session_init=SmartInit(model_path),
model=Model(), model=Model(),
input_names=['z_code', 'z_noise'], input_names=['z_code', 'z_noise'],
output_names=['gen/viz'])) output_names=['gen/viz']))
...@@ -276,5 +276,5 @@ if __name__ == '__main__': ...@@ -276,5 +276,5 @@ if __name__ == '__main__':
callbacks=[ModelSaver(keep_checkpoint_every_n_hours=0.1)], callbacks=[ModelSaver(keep_checkpoint_every_n_hours=0.1)],
steps_per_epoch=500, steps_per_epoch=500,
max_epoch=100, max_epoch=100,
session_init=SaverRestore(args.load) if args.load else None session_init=SmartInit(args.load)
) )
...@@ -80,5 +80,5 @@ if __name__ == '__main__': ...@@ -80,5 +80,5 @@ if __name__ == '__main__':
callbacks=[ModelSaver(), ClipCallback()], callbacks=[ModelSaver(), ClipCallback()],
steps_per_epoch=500, steps_per_epoch=500,
max_epoch=200, max_epoch=200,
session_init=SaverRestore(args.load) if args.load else None session_init=SmartInit(args.load)
) )
...@@ -271,7 +271,7 @@ def get_config(): ...@@ -271,7 +271,7 @@ def get_config():
def run(model_path, image_path, output): def run(model_path, image_path, output):
pred_config = PredictConfig( pred_config = PredictConfig(
model=Model(), model=Model(),
session_init=get_model_loader(model_path), session_init=SmartInit(model_path),
input_names=['image'], input_names=['image'],
output_names=['output' + str(k) for k in range(1, 7)]) output_names=['output' + str(k) for k in range(1, 7)])
predictor = OfflinePredictor(pred_config) predictor = OfflinePredictor(pred_config)
...@@ -309,8 +309,7 @@ if __name__ == '__main__': ...@@ -309,8 +309,7 @@ if __name__ == '__main__':
run(args.load, args.run, args.output) run(args.load, args.run, args.output)
else: else:
config = get_config() config = get_config()
if args.load: config.session_init = SmartInit(args.load)
config.session_init = get_model_loader(args.load)
launch_train_with_config( launch_train_with_config(
config, config,
SyncMultiGPUTrainer(max(get_num_gpu(), 1))) SyncMultiGPUTrainer(max(get_num_gpu(), 1)))
...@@ -431,7 +431,7 @@ class ImageNetModel(ModelDesc): ...@@ -431,7 +431,7 @@ class ImageNetModel(ModelDesc):
Examples: Examples:
pred = OfflinePredictor(model.create_predict_config(get_model_loader(args.load))) pred = OfflinePredictor(model.create_predict_config(SmartInit(args.load)))
prob = pred(NCHW_image)[0] # Nx1000 probabilities prob = pred(NCHW_image)[0] # Nx1000 probabilities
""" """
return PredictConfig(model=self, input_names=['input'], output_names=['prob'], session_init=session_init) return PredictConfig(model=self, input_names=['input'], output_names=['prob'], session_init=session_init)
......
...@@ -166,8 +166,7 @@ if __name__ == '__main__': ...@@ -166,8 +166,7 @@ if __name__ == '__main__':
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
config = get_config() config = get_config()
if args.load: config.session_init = SmartInit(args.load)
config.session_init = SaverRestore(args.load)
nr_tower = get_num_gpu() nr_tower = get_num_gpu()
assert nr_tower == NUM_GPU assert nr_tower == NUM_GPU
launch_train_with_config(config, SyncMultiGPUTrainer(NUM_GPU)) launch_train_with_config(config, SyncMultiGPUTrainer(NUM_GPU))
...@@ -11,7 +11,7 @@ import tensorflow as tf ...@@ -11,7 +11,7 @@ import tensorflow as tf
from tensorpack import * from tensorpack import *
from tensorpack.dataflow import imgaug from tensorpack.dataflow import imgaug
from tensorpack.tfutils import argscope, get_model_loader, model_utils from tensorpack.tfutils import argscope, SmartInit, model_utils
from tensorpack.tfutils.scope_utils import under_name_scope from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.utils import logger from tensorpack.utils import logger
from tensorpack.utils.gpu import get_num_gpu from tensorpack.utils.gpu import get_num_gpu
...@@ -251,7 +251,7 @@ if __name__ == '__main__': ...@@ -251,7 +251,7 @@ if __name__ == '__main__':
if args.eval: if args.eval:
batch = 128 # something that can run on one gpu batch = 128 # something that can run on one gpu
ds = get_data('val', batch) ds = get_data('val', batch)
eval_classification(model, get_model_loader(args.load), ds) eval_classification(model, SmartInit(args.load), ds)
elif args.flops: elif args.flops:
# manually build the graph with batch=1 # manually build the graph with batch=1
with TowerContext('', is_training=False): with TowerContext('', is_training=False):
...@@ -277,6 +277,5 @@ if __name__ == '__main__': ...@@ -277,6 +277,5 @@ if __name__ == '__main__':
nr_tower = max(get_num_gpu(), 1) nr_tower = max(get_num_gpu(), 1)
config = get_config(model, nr_tower) config = get_config(model, nr_tower)
if args.load: config.session_init = SmartInit(args.load)
config.session_init = get_model_loader(args.load)
launch_train_with_config(config, SyncMultiGPUTrainerParameterServer(nr_tower)) launch_train_with_config(config, SyncMultiGPUTrainerParameterServer(nr_tower))
...@@ -24,7 +24,7 @@ def apply(model, model_path, images, ground_truth=None): ...@@ -24,7 +24,7 @@ def apply(model, model_path, images, ground_truth=None):
predict_func = OfflinePredictor(PredictConfig( predict_func = OfflinePredictor(PredictConfig(
model=model(height=newh, width=neww), model=model(height=newh, width=neww),
session_init=get_model_loader(model_path), session_init=SmartInit(model_path),
input_names=['left', 'right'], input_names=['left', 'right'],
output_names=['prediction'])) output_names=['prediction']))
...@@ -102,7 +102,7 @@ def inference(model, model_path, sintel_path): ...@@ -102,7 +102,7 @@ def inference(model, model_path, sintel_path):
pred = PredictConfig( pred = PredictConfig(
model=model(height=h, width=w), model=model(height=h, width=w),
session_init=get_model_loader(model_path), session_init=SmartInit(model_path),
input_names=['left', 'right', 'gt_flow'], input_names=['left', 'right', 'gt_flow'],
output_names=['epe', 'prediction']) output_names=['epe', 'prediction'])
pred = SimpleDatasetPredictor(pred, ds) pred = SimpleDatasetPredictor(pred, ds)
......
...@@ -174,6 +174,5 @@ if __name__ == '__main__': ...@@ -174,6 +174,5 @@ if __name__ == '__main__':
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
config = get_config() config = get_config()
if args.load: config.session_init = SmartInit(args.load)
config.session_init = SaverRestore(args.load)
launch_train_with_config(config, SimpleTrainer()) launch_train_with_config(config, SimpleTrainer())
...@@ -149,6 +149,6 @@ if __name__ == '__main__': ...@@ -149,6 +149,6 @@ if __name__ == '__main__':
], ],
max_epoch=200, max_epoch=200,
steps_per_epoch=len(dataset_train), steps_per_epoch=len(dataset_train),
session_init=SaverRestore(args.load) if args.load else None session_init=SmartInit(args.load)
) )
launch_train_with_config(config, SimpleTrainer()) launch_train_with_config(config, SimpleTrainer())
...@@ -166,7 +166,7 @@ if __name__ == '__main__': ...@@ -166,7 +166,7 @@ if __name__ == '__main__':
[(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0002)]) [(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0002)])
], ],
max_epoch=400, max_epoch=400,
session_init=SaverRestore(args.load) if args.load else None session_init=SmartInit(args.load),
) )
num_gpu = max(get_num_gpu(), 1) num_gpu = max(get_num_gpu(), 1)
launch_train_with_config(config, SyncMultiGPUTrainerParameterServer(num_gpu)) launch_train_with_config(config, SyncMultiGPUTrainerParameterServer(num_gpu))
...@@ -9,7 +9,7 @@ from tensorpack import QueueInput, TFDatasetInput, logger ...@@ -9,7 +9,7 @@ from tensorpack import QueueInput, TFDatasetInput, logger
from tensorpack.callbacks import * from tensorpack.callbacks import *
from tensorpack.dataflow import FakeData from tensorpack.dataflow import FakeData
from tensorpack.models import * from tensorpack.models import *
from tensorpack.tfutils import argscope, get_model_loader from tensorpack.tfutils import argscope, SmartInit
from tensorpack.train import SyncMultiGPUTrainerReplicated, TrainConfig, launch_train_with_config from tensorpack.train import SyncMultiGPUTrainerReplicated, TrainConfig, launch_train_with_config
from tensorpack.utils.gpu import get_num_gpu from tensorpack.utils.gpu import get_num_gpu
...@@ -136,7 +136,7 @@ if __name__ == '__main__': ...@@ -136,7 +136,7 @@ if __name__ == '__main__':
if args.eval: if args.eval:
batch = 128 # something that can run on one gpu batch = 128 # something that can run on one gpu
ds = get_imagenet_dataflow(args.data, 'val', batch) ds = get_imagenet_dataflow(args.data, 'val', batch)
eval_classification(model, get_model_loader(args.load), ds) eval_classification(model, SmartInit(args.load), ds)
else: else:
if args.fake: if args.fake:
logger.set_logger_dir(os.path.join('train_log', 'tmp'), 'd') logger.set_logger_dir(os.path.join('train_log', 'tmp'), 'd')
...@@ -147,7 +147,6 @@ if __name__ == '__main__': ...@@ -147,7 +147,6 @@ if __name__ == '__main__':
args.mode, args.depth, args.batch))) args.mode, args.depth, args.batch)))
config = get_config(model) config = get_config(model)
if args.load: config.session_init = SmartInit(args.load)
config.session_init = get_model_loader(args.load)
trainer = SyncMultiGPUTrainerReplicated(max(get_num_gpu(), 1)) trainer = SyncMultiGPUTrainerReplicated(max(get_num_gpu(), 1))
launch_train_with_config(config, trainer) launch_train_with_config(config, trainer)
...@@ -79,7 +79,7 @@ def get_inference_augmentor(): ...@@ -79,7 +79,7 @@ def get_inference_augmentor():
def run_test(params, input): def run_test(params, input):
pred_config = PredictConfig( pred_config = PredictConfig(
model=Model(), model=Model(),
session_init=DictRestore(params), session_init=SmartInit(params),
input_names=['input'], input_names=['input'],
output_names=['prob'] output_names=['prob']
) )
...@@ -172,6 +172,6 @@ if __name__ == '__main__': ...@@ -172,6 +172,6 @@ if __name__ == '__main__':
if args.eval: if args.eval:
ds = get_imagenet_dataflow(args.eval, 'val', 128, get_inference_augmentor()) ds = get_imagenet_dataflow(args.eval, 'val', 128, get_inference_augmentor())
eval_classification(Model(), DictRestore(param), ds) eval_classification(Model(), SmartRestore(param), ds)
elif args.input: elif args.input:
run_test(param, args.input) run_test(param, args.input)
...@@ -97,7 +97,7 @@ def viz_cam(model_file, data_dir): ...@@ -97,7 +97,7 @@ def viz_cam(model_file, data_dir):
ds = get_data('val') ds = get_data('val')
pred_config = PredictConfig( pred_config = PredictConfig(
model=Model(), model=Model(),
session_init=get_model_loader(model_file), session_init=SmartInit(model_file),
input_names=['input', 'label'], input_names=['input', 'label'],
output_names=['wrong-top1', 'group3new/bnlast/Relu', 'linearnew/W'], output_names=['wrong-top1', 'group3new/bnlast/Relu', 'linearnew/W'],
return_input=True return_input=True
...@@ -151,6 +151,5 @@ if __name__ == '__main__': ...@@ -151,6 +151,5 @@ if __name__ == '__main__':
logger.auto_set_dir() logger.auto_set_dir()
config = get_config() config = get_config()
if args.load: config.session_init = SmartInit(args.load)
config.session_init = get_model_loader(args.load)
launch_train_with_config(config, SyncMultiGPUTrainerParameterServer(num_gpu)) launch_train_with_config(config, SyncMultiGPUTrainerParameterServer(num_gpu))
...@@ -68,7 +68,7 @@ class Model(tp.ModelDescBase): ...@@ -68,7 +68,7 @@ class Model(tp.ModelDescBase):
def run(model_path, image_path): def run(model_path, image_path):
predictor = tp.OfflinePredictor(tp.PredictConfig( predictor = tp.OfflinePredictor(tp.PredictConfig(
model=Model(), model=Model(),
session_init=tp.get_model_loader(model_path), session_init=tp.SmartInit(model_path),
input_names=['image'], input_names=['image'],
output_names=['saliency'])) output_names=['saliency']))
im = cv2.imread(image_path) im = cv2.imread(image_path)
......
...@@ -364,7 +364,7 @@ def visualize(model_path, model, algo_name): ...@@ -364,7 +364,7 @@ def visualize(model_path, model, algo_name):
logger.error("visualize requires matplotlib package ...") logger.error("visualize requires matplotlib package ...")
return return
pred = OfflinePredictor(PredictConfig( pred = OfflinePredictor(PredictConfig(
session_init=get_model_loader(model_path), session_init=SmartInit(model_path),
model=model(), model=model(),
input_names=['input'], input_names=['input'],
output_names=['emb'])) output_names=['emb']))
...@@ -432,7 +432,5 @@ if __name__ == '__main__': ...@@ -432,7 +432,5 @@ if __name__ == '__main__':
visualize(args.load, ALGO_CONFIGS[args.algorithm], args.algorithm) visualize(args.load, ALGO_CONFIGS[args.algorithm], args.algorithm)
else: else:
config = get_config(ALGO_CONFIGS[args.algorithm], args.algorithm) config = get_config(ALGO_CONFIGS[args.algorithm], args.algorithm)
if args.load: config.session_init = SmartInit(args.load)
config.session_init = SaverRestore(args.load) launch_train_with_config(config, SimpleTrainer())
else:
launch_train_with_config(config, SimpleTrainer())
...@@ -201,7 +201,7 @@ def get_data(isTrain): ...@@ -201,7 +201,7 @@ def get_data(isTrain):
def view_warp(modelpath): def view_warp(modelpath):
pred = OfflinePredictor(PredictConfig( pred = OfflinePredictor(PredictConfig(
session_init=get_model_loader(modelpath), session_init=SmartInit(modelpath),
model=Model(), model=Model(),
input_names=['input'], input_names=['input'],
output_names=['visualization/viz', 'STN1/affine', 'STN2/affine'])) output_names=['visualization/viz', 'STN1/affine', 'STN2/affine']))
...@@ -265,6 +265,5 @@ if __name__ == '__main__': ...@@ -265,6 +265,5 @@ if __name__ == '__main__':
view_warp(args.load) view_warp(args.load)
else: else:
config = get_config() config = get_config()
if args.load: config.session_init = SmartInit(args.load)
config.session_init = SaverRestore(args.load)
launch_train_with_config(config, SimpleTrainer()) launch_train_with_config(config, SimpleTrainer())
...@@ -227,7 +227,7 @@ def apply(model_path, lowres_path="", output_path='.'): ...@@ -227,7 +227,7 @@ def apply(model_path, lowres_path="", output_path='.'):
predict_func = OfflinePredictor(PredictConfig( predict_func = OfflinePredictor(PredictConfig(
model=Model(LR_SIZE_H, LR_SIZE_W), model=Model(LR_SIZE_H, LR_SIZE_W),
session_init=get_model_loader(model_path), session_init=SmartInit(model_path),
input_names=['Ilr'], input_names=['Ilr'],
output_names=['prediction'])) output_names=['prediction']))
...@@ -279,12 +279,12 @@ if __name__ == '__main__': ...@@ -279,12 +279,12 @@ if __name__ == '__main__':
logger.auto_set_dir() logger.auto_set_dir()
if args.load: if args.load:
session_init = SaverRestore(args.load) session_init = SmartInit(args.load)
else: else:
assert os.path.isfile(args.vgg19) assert os.path.isfile(args.vgg19)
param_dict = dict(np.load(args.vgg19)) param_dict = dict(np.load(args.vgg19))
param_dict = {'VGG19/' + name: value for name, value in six.iteritems(param_dict)} param_dict = {'VGG19/' + name: value for name, value in six.iteritems(param_dict)}
session_init = DictRestore(param_dict) session_init = SmartInit(param_dict)
nr_tower = max(get_num_gpu(), 1) nr_tower = max(get_num_gpu(), 1)
data = QueueInput(get_data(args.data)) data = QueueInput(get_data(args.data))
......
...@@ -143,8 +143,7 @@ if __name__ == '__main__': ...@@ -143,8 +143,7 @@ if __name__ == '__main__':
with tf.Graph().as_default(): with tf.Graph().as_default():
logger.set_logger_dir(os.path.join('train_log', 'cifar' + str(args.classnum))) logger.set_logger_dir(os.path.join('train_log', 'cifar' + str(args.classnum)))
config = get_config(args.classnum) config = get_config(args.classnum)
if args.load: config.session_init = SmartInit(args.load)
config.session_init = SaverRestore(args.load)
num_gpu = get_num_gpu() num_gpu = get_num_gpu()
trainer = SimpleTrainer() if num_gpu <= 1 \ trainer = SimpleTrainer() if num_gpu <= 1 \
......
...@@ -106,7 +106,7 @@ class InferenceOnlyModel(Model): ...@@ -106,7 +106,7 @@ class InferenceOnlyModel(Model):
def export_serving(model_path): def export_serving(model_path):
"""Export trained model to use it in TensorFlow Serving or cloudML. """ """Export trained model to use it in TensorFlow Serving or cloudML. """
pred_config = PredictConfig( pred_config = PredictConfig(
session_init=get_model_loader(model_path), session_init=SmartInit(model_path),
model=InferenceOnlyModel(), model=InferenceOnlyModel(),
input_names=['input_img_bytes'], input_names=['input_img_bytes'],
output_names=['prediction_img_bytes']) output_names=['prediction_img_bytes'])
...@@ -117,7 +117,7 @@ def export_compact(model_path): ...@@ -117,7 +117,7 @@ def export_compact(model_path):
"""Export trained model to use it as a frozen and pruned inference graph in """Export trained model to use it as a frozen and pruned inference graph in
mobile applications. """ mobile applications. """
pred_config = PredictConfig( pred_config = PredictConfig(
session_init=get_model_loader(model_path), session_init=SmartInit(model_path),
model=Model(), model=Model(),
input_names=['input_img'], input_names=['input_img'],
output_names=['prediction_img']) output_names=['prediction_img'])
...@@ -127,7 +127,7 @@ def export_compact(model_path): ...@@ -127,7 +127,7 @@ def export_compact(model_path):
def apply(model_path): def apply(model_path):
"""Run inference from a training model checkpoint. """ """Run inference from a training model checkpoint. """
pred_config = PredictConfig( pred_config = PredictConfig(
session_init=get_model_loader(model_path), session_init=SmartInit(model_path),
model=Model(), model=Model(),
input_names=['input_img'], input_names=['input_img'],
output_names=['prediction_img']) output_names=['prediction_img'])
...@@ -141,7 +141,7 @@ def apply(model_path): ...@@ -141,7 +141,7 @@ def apply(model_path):
def apply_inference_graph(model_path): def apply_inference_graph(model_path):
"""Run inference from a different graph, which receives encoded images buffers. """ """Run inference from a different graph, which receives encoded images buffers. """
pred_config = PredictConfig( pred_config = PredictConfig(
session_init=get_model_loader(model_path), session_init=SmartInit(model_path),
model=InferenceOnlyModel(), model=InferenceOnlyModel(),
input_names=['input_img_bytes'], input_names=['input_img_bytes'],
output_names=['prediction_img_bytes']) output_names=['prediction_img_bytes'])
......
...@@ -107,6 +107,6 @@ if __name__ == '__main__': ...@@ -107,6 +107,6 @@ if __name__ == '__main__':
ScalarStats(['cost', 'accuracy'])) ScalarStats(['cost', 'accuracy']))
], ],
max_epoch=350, max_epoch=350,
session_init=SaverRestore(args.load) if args.load else None session_init=SmartInit(args.load)
) )
launch_train_with_config(config, SimpleTrainer()) launch_train_with_config(config, SimpleTrainer())
...@@ -70,8 +70,6 @@ if __name__ == '__main__': ...@@ -70,8 +70,6 @@ if __name__ == '__main__':
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
config = get_config() config = get_config()
config.session_init = SmartInit(args.load)
if args.load:
config.session_init = SaverRestore(args.load)
launch_train_with_config(config, SimpleTrainer()) launch_train_with_config(config, SimpleTrainer())
...@@ -12,7 +12,7 @@ from .varmanip import SessionUpdate, get_checkpoint_path, get_savename_from_varn ...@@ -12,7 +12,7 @@ from .varmanip import SessionUpdate, get_checkpoint_path, get_savename_from_varn
__all__ = ['SessionInit', 'ChainInit', __all__ = ['SessionInit', 'ChainInit',
'SaverRestore', 'SaverRestoreRelaxed', 'DictRestore', 'SaverRestore', 'SaverRestoreRelaxed', 'DictRestore',
'JustCurrentSession', 'get_model_loader', 'SmartRestore'] 'JustCurrentSession', 'get_model_loader', 'SmartInit']
class SessionInit(object): class SessionInit(object):
...@@ -260,7 +260,7 @@ class ChainInit(SessionInit): ...@@ -260,7 +260,7 @@ class ChainInit(SessionInit):
i._run_init(sess) i._run_init(sess)
def SmartRestore(obj, ignore_mismatch=False): def SmartInit(obj, ignore_mismatch=False):
""" """
Create a :class:`SessionInit` to be loaded to a session, Create a :class:`SessionInit` to be loaded to a session,
automatically from any supported objects, with some smart heuristics. automatically from any supported objects, with some smart heuristics.
...@@ -268,9 +268,9 @@ def SmartRestore(obj, ignore_mismatch=False): ...@@ -268,9 +268,9 @@ def SmartRestore(obj, ignore_mismatch=False):
+ A TF checkpoint + A TF checkpoint
+ A dict of numpy arrays + A dict of numpy arrays
+ A npz file + A npz file, to be interpreted as a dict
+ An empty string or None + An empty string or None, in which case the sessinit will be a no-op
+ A list of supported objects + A list of supported objects, to be initialized one by one
Args: Args:
obj: a supported object obj: a supported object
...@@ -285,7 +285,7 @@ def SmartRestore(obj, ignore_mismatch=False): ...@@ -285,7 +285,7 @@ def SmartRestore(obj, ignore_mismatch=False):
if not obj: if not obj:
return JustCurrentSession() return JustCurrentSession()
if isinstance(obj, list): if isinstance(obj, list):
return ChainInit([SmartRestore(x, ignore_mismatch=ignore_mismatch) for x in obj]) return ChainInit([SmartInit(x, ignore_mismatch=ignore_mismatch) for x in obj])
if isinstance(obj, six.string_types): if isinstance(obj, six.string_types):
obj = os.path.expanduser(obj) obj = os.path.expanduser(obj)
if obj.endswith(".npy") or obj.endswith(".npz"): if obj.endswith(".npy") or obj.endswith(".npz"):
...@@ -301,11 +301,11 @@ def SmartRestore(obj, ignore_mismatch=False): ...@@ -301,11 +301,11 @@ def SmartRestore(obj, ignore_mismatch=False):
# A TF checkpoint must be a prefix of an actual file. # A TF checkpoint must be a prefix of an actual file.
return (SaverRestoreRelaxed if ignore_mismatch else SaverRestore)(obj) return (SaverRestoreRelaxed if ignore_mismatch else SaverRestore)(obj)
else: else:
raise ValueError("Invalid argument to SmartRestore: " + obj) raise ValueError("Invalid argument to SmartInit: " + obj)
if isinstance(obj, dict): if isinstance(obj, dict):
return DictRestore(obj, ignore_mismatch=ignore_mismatch) return DictRestore(obj, ignore_mismatch=ignore_mismatch)
raise ValueError("Invalid argument to SmartRestore: " + type(obj)) raise ValueError("Invalid argument to SmartInit: " + type(obj))
get_model_loader = SmartRestore get_model_loader = SmartInit
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment