Commit 17b34c69 authored by Yuxin Wu's avatar Yuxin Wu

Use SmartInit globally - a simpler interface to initialization

parent cbd698ad
......@@ -39,13 +39,13 @@ For inference, use `session_init` in `PredictConfig(...)`.
There are a few ways a session can be initialized:
```
session_init=SmartRestore("path/to/checkpoint") # load a TF checkpoint
session_init=SmartRestore("path/to/model_zoo.npz") # load tensorpack model zoo
session_init=SmartRestore(dict_of_parameters) # load a dictionary
session_init=SmartRestore(["path1", dict2]) # load them sequentially
session_init=SmartInit("path/to/checkpoint") # load a TF checkpoint
session_init=SmartInit("path/to/model_zoo.npz") # load tensorpack model zoo
session_init=SmartInit(dict_of_parameters) # load a dictionary
session_init=SmartInit(["path1", dict2]) # load them sequentially
```
[SmartRestore](../modules/tfutils.html#tensorpack.tfutils.sessinit.SmartRestore)
[SmartInit](../modules/tfutils.html#tensorpack.tfutils.sessinit.SmartInit)
is in fact a small helper which uses some heuristics to return you one of
[SaverRestore](../modules/tfutils.html#tensorpack.tfutils.sessinit.SaverRestore) or
[DictRestore](../modules/tfutils.html#tensorpack.tfutils.sessinit.DictRestore).
......
......@@ -265,7 +265,7 @@ def train():
],
session_creator=sesscreate.NewSessionCreator(config=get_default_sess_config(0.5)),
steps_per_epoch=STEPS_PER_EPOCH,
session_init=get_model_loader(args.load) if args.load else None,
session_init=SmartInit(args.load),
max_epoch=1000,
)
trainer = SimpleTrainer() if num_gpu == 1 else AsyncMultiGPUTrainer(train_tower)
......@@ -294,7 +294,7 @@ if __name__ == '__main__':
assert args.load is not None
pred = OfflinePredictor(PredictConfig(
model=Model(),
session_init=get_model_loader(args.load),
session_init=SmartInit(args.load),
input_names=['state'],
output_names=['policy']))
if args.task == 'play':
......
......@@ -119,6 +119,5 @@ if __name__ == '__main__':
ds_test = get_data(args.test, False, args.stat)
config = get_config(ds_train, ds_test)
if args.load:
config.session_init = SaverRestore(args.load)
config.session_init = SmartInit(args.load)
launch_train_with_config(config, SimpleTrainer())
......@@ -5,7 +5,6 @@
from __future__ import print_function
import argparse
import numpy as np
import os
import cv2
import tensorflow as tf
......@@ -39,11 +38,10 @@ def tower_func(image):
def run_test(path, input):
param_dict = dict(np.load(path))
predictor = OfflinePredictor(PredictConfig(
input_signature=[tf.TensorSpec((None, 227, 227, 3), tf.float32, 'input')],
tower_func=tower_func,
session_init=DictRestore(param_dict),
session_init=SmartInit(path),
input_names=['input'],
output_names=['prob']
))
......
......@@ -95,11 +95,10 @@ def CPM(image):
def run_test(model_path, img_file):
param_dict = dict(np.load(model_path))
predict_func = OfflinePredictor(PredictConfig(
input_signature=[tf.TensorSpec((None, 368, 368, 3), tf.float32, 'input')],
tower_func=CPM,
session_init=DictRestore(param_dict),
session_init=SmartInit(model_path),
input_names=['input'],
output_names=['resized_map']
))
......
......@@ -61,7 +61,7 @@ def run_test(path, input):
predict_func = OfflinePredictor(PredictConfig(
input_signature=[tf.TensorSpec((None, 224, 224, 3), tf.float32, 'input')],
tower_func=tower_func,
session_init=DictRestore(param_dict),
session_init=SmartInit(param_dict),
input_names=['input'],
output_names=['prob'] # prob:0 is the probability distribution
))
......
......@@ -64,7 +64,7 @@ def run_test(path, input):
predict_func = OfflinePredictor(PredictConfig(
input_signature=[tf.TensorSpec((None, 224, 224, 3), tf.float32, 'input')],
tower_func=tower_func,
session_init=DictRestore(param_dict),
session_init=SmartInit(param_dict),
input_names=['input'],
output_names=['prob'] # prob:0 is the probability distribution
))
......
......@@ -141,7 +141,7 @@ def sample(path, start, length):
pred = OfflinePredictor(PredictConfig(
model=Model(),
session_init=SaverRestore(path),
session_init=SmartInit(path),
input_names=['input', 'c0', 'h0', 'c1', 'h1'],
output_names=['prob', 'last_state']))
......@@ -193,6 +193,5 @@ if __name__ == '__main__':
else:
param.corpus = args.corpus
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
config.session_init = SmartInit(args.load)
launch_train_with_config(config, SimpleTrainer())
......@@ -171,7 +171,7 @@ if __name__ == '__main__':
assert args.load is not None
pred = OfflinePredictor(PredictConfig(
model=model,
session_init=get_model_loader(args.load),
session_init=SmartInit(args.load),
input_names=['state'],
output_names=['Qvalue']))
if args.task == 'play':
......@@ -183,6 +183,5 @@ if __name__ == '__main__':
os.path.join('train_log', 'DQN-{}'.format(
os.path.basename(args.env).split('.')[0])))
config = get_config(model)
if args.load:
config.session_init = get_model_loader(args.load)
config.session_init = SmartInit(args.load)
launch_train_with_config(config, SimpleTrainer())
......@@ -12,7 +12,7 @@ import tensorflow as tf
from tensorpack import *
from tensorpack.dataflow import dataset
from tensorpack.tfutils.sessinit import get_model_loader
from tensorpack.tfutils.sessinit import SmartInit
from tensorpack.tfutils.summary import add_param_summary
from tensorpack.tfutils.varreplace import remap_variables
from tensorpack.utils.gpu import get_num_gpu
......@@ -214,12 +214,12 @@ if __name__ == '__main__':
if args.run:
assert args.load.endswith('.npz')
run_image(Model(), DictRestore(dict(np.load(args.load))), args.run)
run_image(Model(), SmartInit(args.load), args.run)
sys.exit()
if args.eval:
BATCH_SIZE = 128
ds = get_data('val')
eval_classification(Model(), get_model_loader(args.load), ds)
eval_classification(Model(), SmartInit(args.load), ds)
sys.exit()
nr_tower = max(get_num_gpu(), 1)
......@@ -229,6 +229,5 @@ if __name__ == '__main__':
logger.info("Batch per tower: {}".format(BATCH_SIZE))
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
config.session_init = SmartInit(args.load)
launch_train_with_config(config, SyncMultiGPUTrainerReplicated(nr_tower))
......@@ -163,7 +163,7 @@ if __name__ == '__main__':
ds = dataset.ILSVRC12(args.data, 'val', shuffle=False)
ds = AugmentImageComponent(ds, get_inference_augmentor())
ds = BatchData(ds, 192, remainder=True)
eval_classification(Model(), get_model_loader(args.load), ds)
eval_classification(Model(), SmartInit(args.load), ds)
elif args.run:
assert args.load.endswith('.npz')
run_image(Model(), DictRestore(dict(np.load(args.load))), args.run)
run_image(Model(), SmartInit(args.load), args.run)
......@@ -255,6 +255,5 @@ if __name__ == '__main__':
with change_gpu(args.gpu):
NGPU = len(args.gpu.split(','))
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
config.session_init = SmartInit(args.load)
launch_train_with_config(config, SyncMultiGPUTrainer(NGPU))
......@@ -14,7 +14,7 @@ assert six.PY3, "This example requires Python 3!"
import tensorpack.utils.viz as tpviz
from tensorpack.predict import MultiTowerOfflinePredictor, OfflinePredictor, PredictConfig
from tensorpack.tfutils import get_model_loader, get_tf_version_tuple
from tensorpack.tfutils import SmartInit, get_tf_version_tuple
from tensorpack.tfutils.export import ModelExporter
from tensorpack.utils import fs, logger
......@@ -38,7 +38,7 @@ def do_visualize(model, model_path, nr_visualize=100, output_dir='output'):
pred = OfflinePredictor(PredictConfig(
model=model,
session_init=get_model_loader(model_path),
session_init=SmartInit(model_path),
input_names=['image', 'gt_boxes', 'gt_labels'],
output_names=[
'generate_{}_proposals/boxes'.format('fpn' if cfg.MODE_FPN else 'rpn'),
......@@ -146,7 +146,7 @@ if __name__ == '__main__':
else:
predcfg = PredictConfig(
model=MODEL,
session_init=get_model_loader(args.load),
session_init=SmartInit(args.load),
input_names=MODEL.get_inference_tensor_names()[0],
output_names=MODEL.get_inference_tensor_names()[1])
......
......@@ -103,9 +103,9 @@ if __name__ == '__main__':
else:
if args.load:
# ignore mismatched values, so you can `--load` a model for fine-tuning
session_init = SmartRestore(args.load, ignore_mismatch=True)
session_init = SmartInit(args.load, ignore_mismatch=True)
else:
session_init = SmartRestore(cfg.BACKBONE.WEIGHTS)
session_init = SmartInit(cfg.BACKBONE.WEIGHTS)
traincfg = TrainConfig(
model=MODEL,
......
......@@ -146,5 +146,5 @@ if __name__ == '__main__':
StatMonitorParamSetter(
'learning_rate', 'losses/measure', lambda x: x * 0.5, 0, 10)
],
session_init=SaverRestore(args.load) if args.load else None,
session_init=SmartInit(args.load),
steps_per_epoch=500, max_epoch=400)
......@@ -114,7 +114,7 @@ def get_data():
def sample(model_path):
pred = PredictConfig(
session_init=get_model_loader(model_path),
session_init=SmartInit(model_path),
model=Model(),
input_names=['label', 'z'],
output_names=['gen/gen'])
......@@ -145,5 +145,5 @@ if __name__ == '__main__':
callbacks=[ModelSaver()],
steps_per_epoch=500,
max_epoch=100,
session_init=SaverRestore(args.load) if args.load else None
session_init=SmartInit(args.load),
)
......@@ -224,5 +224,5 @@ if __name__ == '__main__':
],
max_epoch=195,
steps_per_epoch=data.size(),
session_init=SaverRestore(args.load) if args.load else None
session_init=SmartInit(args.load)
)
......@@ -121,7 +121,7 @@ def get_data():
def sample(model, model_path, output_name='gen/gen'):
pred = PredictConfig(
session_init=get_model_loader(model_path),
session_init=SmartInit(model_path),
model=model,
input_names=['z'],
output_names=[output_name, 'z'])
......@@ -167,5 +167,5 @@ if __name__ == '__main__':
callbacks=[ModelSaver()],
steps_per_epoch=300,
max_epoch=200,
session_init=SaverRestore(args.load) if args.load else None
session_init=SmartInit(args.load),
)
......@@ -211,5 +211,5 @@ if __name__ == '__main__':
callbacks=[ModelSaver()],
steps_per_epoch=300,
max_epoch=250,
session_init=SaverRestore(args.load) if args.load else None
session_init=SmartInit(args.load),
)
......@@ -179,7 +179,7 @@ def get_data():
def sample(datadir, model_path):
pred = PredictConfig(
session_init=get_model_loader(model_path),
session_init=SmartInit(model_path),
model=Model(),
input_names=['input', 'output'],
output_names=['viz'])
......@@ -226,5 +226,5 @@ if __name__ == '__main__':
],
steps_per_epoch=data.size(),
max_epoch=300,
session_init=SaverRestore(args.load) if args.load else None
session_init=SmartInit(args.load)
)
......@@ -97,5 +97,5 @@ if __name__ == '__main__':
callbacks=[ModelSaver()],
steps_per_epoch=300,
max_epoch=200,
session_init=SaverRestore(args.load) if args.load else None
session_init=SmartInit(args.load)
)
......@@ -218,7 +218,7 @@ def get_data():
def sample(model_path):
pred = OfflinePredictor(PredictConfig(
session_init=get_model_loader(model_path),
session_init=SmartInit(model_path),
model=Model(),
input_names=['z_code', 'z_noise'],
output_names=['gen/viz']))
......@@ -276,5 +276,5 @@ if __name__ == '__main__':
callbacks=[ModelSaver(keep_checkpoint_every_n_hours=0.1)],
steps_per_epoch=500,
max_epoch=100,
session_init=SaverRestore(args.load) if args.load else None
session_init=SmartInit(args.load)
)
......@@ -80,5 +80,5 @@ if __name__ == '__main__':
callbacks=[ModelSaver(), ClipCallback()],
steps_per_epoch=500,
max_epoch=200,
session_init=SaverRestore(args.load) if args.load else None
session_init=SmartInit(args.load)
)
......@@ -271,7 +271,7 @@ def get_config():
def run(model_path, image_path, output):
pred_config = PredictConfig(
model=Model(),
session_init=get_model_loader(model_path),
session_init=SmartInit(model_path),
input_names=['image'],
output_names=['output' + str(k) for k in range(1, 7)])
predictor = OfflinePredictor(pred_config)
......@@ -309,8 +309,7 @@ if __name__ == '__main__':
run(args.load, args.run, args.output)
else:
config = get_config()
if args.load:
config.session_init = get_model_loader(args.load)
config.session_init = SmartInit(args.load)
launch_train_with_config(
config,
SyncMultiGPUTrainer(max(get_num_gpu(), 1)))
......@@ -431,7 +431,7 @@ class ImageNetModel(ModelDesc):
Examples:
pred = OfflinePredictor(model.create_predict_config(get_model_loader(args.load)))
pred = OfflinePredictor(model.create_predict_config(SmartInit(args.load)))
prob = pred(NCHW_image)[0] # Nx1000 probabilities
"""
return PredictConfig(model=self, input_names=['input'], output_names=['prob'], session_init=session_init)
......
......@@ -166,8 +166,7 @@ if __name__ == '__main__':
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
config.session_init = SmartInit(args.load)
nr_tower = get_num_gpu()
assert nr_tower == NUM_GPU
launch_train_with_config(config, SyncMultiGPUTrainer(NUM_GPU))
......@@ -11,7 +11,7 @@ import tensorflow as tf
from tensorpack import *
from tensorpack.dataflow import imgaug
from tensorpack.tfutils import argscope, get_model_loader, model_utils
from tensorpack.tfutils import argscope, SmartInit, model_utils
from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.utils import logger
from tensorpack.utils.gpu import get_num_gpu
......@@ -251,7 +251,7 @@ if __name__ == '__main__':
if args.eval:
batch = 128 # something that can run on one gpu
ds = get_data('val', batch)
eval_classification(model, get_model_loader(args.load), ds)
eval_classification(model, SmartInit(args.load), ds)
elif args.flops:
# manually build the graph with batch=1
with TowerContext('', is_training=False):
......@@ -277,6 +277,5 @@ if __name__ == '__main__':
nr_tower = max(get_num_gpu(), 1)
config = get_config(model, nr_tower)
if args.load:
config.session_init = get_model_loader(args.load)
config.session_init = SmartInit(args.load)
launch_train_with_config(config, SyncMultiGPUTrainerParameterServer(nr_tower))
......@@ -24,7 +24,7 @@ def apply(model, model_path, images, ground_truth=None):
predict_func = OfflinePredictor(PredictConfig(
model=model(height=newh, width=neww),
session_init=get_model_loader(model_path),
session_init=SmartInit(model_path),
input_names=['left', 'right'],
output_names=['prediction']))
......@@ -102,7 +102,7 @@ def inference(model, model_path, sintel_path):
pred = PredictConfig(
model=model(height=h, width=w),
session_init=get_model_loader(model_path),
session_init=SmartInit(model_path),
input_names=['left', 'right', 'gt_flow'],
output_names=['epe', 'prediction'])
pred = SimpleDatasetPredictor(pred, ds)
......
......@@ -174,6 +174,5 @@ if __name__ == '__main__':
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
config.session_init = SmartInit(args.load)
launch_train_with_config(config, SimpleTrainer())
......@@ -149,6 +149,6 @@ if __name__ == '__main__':
],
max_epoch=200,
steps_per_epoch=len(dataset_train),
session_init=SaverRestore(args.load) if args.load else None
session_init=SmartInit(args.load)
)
launch_train_with_config(config, SimpleTrainer())
......@@ -166,7 +166,7 @@ if __name__ == '__main__':
[(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0002)])
],
max_epoch=400,
session_init=SaverRestore(args.load) if args.load else None
session_init=SmartInit(args.load),
)
num_gpu = max(get_num_gpu(), 1)
launch_train_with_config(config, SyncMultiGPUTrainerParameterServer(num_gpu))
......@@ -9,7 +9,7 @@ from tensorpack import QueueInput, TFDatasetInput, logger
from tensorpack.callbacks import *
from tensorpack.dataflow import FakeData
from tensorpack.models import *
from tensorpack.tfutils import argscope, get_model_loader
from tensorpack.tfutils import argscope, SmartInit
from tensorpack.train import SyncMultiGPUTrainerReplicated, TrainConfig, launch_train_with_config
from tensorpack.utils.gpu import get_num_gpu
......@@ -136,7 +136,7 @@ if __name__ == '__main__':
if args.eval:
batch = 128 # something that can run on one gpu
ds = get_imagenet_dataflow(args.data, 'val', batch)
eval_classification(model, get_model_loader(args.load), ds)
eval_classification(model, SmartInit(args.load), ds)
else:
if args.fake:
logger.set_logger_dir(os.path.join('train_log', 'tmp'), 'd')
......@@ -147,7 +147,6 @@ if __name__ == '__main__':
args.mode, args.depth, args.batch)))
config = get_config(model)
if args.load:
config.session_init = get_model_loader(args.load)
config.session_init = SmartInit(args.load)
trainer = SyncMultiGPUTrainerReplicated(max(get_num_gpu(), 1))
launch_train_with_config(config, trainer)
......@@ -79,7 +79,7 @@ def get_inference_augmentor():
def run_test(params, input):
pred_config = PredictConfig(
model=Model(),
session_init=DictRestore(params),
session_init=SmartInit(params),
input_names=['input'],
output_names=['prob']
)
......@@ -172,6 +172,6 @@ if __name__ == '__main__':
if args.eval:
ds = get_imagenet_dataflow(args.eval, 'val', 128, get_inference_augmentor())
eval_classification(Model(), DictRestore(param), ds)
eval_classification(Model(), SmartRestore(param), ds)
elif args.input:
run_test(param, args.input)
......@@ -97,7 +97,7 @@ def viz_cam(model_file, data_dir):
ds = get_data('val')
pred_config = PredictConfig(
model=Model(),
session_init=get_model_loader(model_file),
session_init=SmartInit(model_file),
input_names=['input', 'label'],
output_names=['wrong-top1', 'group3new/bnlast/Relu', 'linearnew/W'],
return_input=True
......@@ -151,6 +151,5 @@ if __name__ == '__main__':
logger.auto_set_dir()
config = get_config()
if args.load:
config.session_init = get_model_loader(args.load)
config.session_init = SmartInit(args.load)
launch_train_with_config(config, SyncMultiGPUTrainerParameterServer(num_gpu))
......@@ -68,7 +68,7 @@ class Model(tp.ModelDescBase):
def run(model_path, image_path):
predictor = tp.OfflinePredictor(tp.PredictConfig(
model=Model(),
session_init=tp.get_model_loader(model_path),
session_init=tp.SmartInit(model_path),
input_names=['image'],
output_names=['saliency']))
im = cv2.imread(image_path)
......
......@@ -364,7 +364,7 @@ def visualize(model_path, model, algo_name):
logger.error("visualize requires matplotlib package ...")
return
pred = OfflinePredictor(PredictConfig(
session_init=get_model_loader(model_path),
session_init=SmartInit(model_path),
model=model(),
input_names=['input'],
output_names=['emb']))
......@@ -432,7 +432,5 @@ if __name__ == '__main__':
visualize(args.load, ALGO_CONFIGS[args.algorithm], args.algorithm)
else:
config = get_config(ALGO_CONFIGS[args.algorithm], args.algorithm)
if args.load:
config.session_init = SaverRestore(args.load)
else:
launch_train_with_config(config, SimpleTrainer())
config.session_init = SmartInit(args.load)
launch_train_with_config(config, SimpleTrainer())
......@@ -201,7 +201,7 @@ def get_data(isTrain):
def view_warp(modelpath):
pred = OfflinePredictor(PredictConfig(
session_init=get_model_loader(modelpath),
session_init=SmartInit(modelpath),
model=Model(),
input_names=['input'],
output_names=['visualization/viz', 'STN1/affine', 'STN2/affine']))
......@@ -265,6 +265,5 @@ if __name__ == '__main__':
view_warp(args.load)
else:
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
config.session_init = SmartInit(args.load)
launch_train_with_config(config, SimpleTrainer())
......@@ -227,7 +227,7 @@ def apply(model_path, lowres_path="", output_path='.'):
predict_func = OfflinePredictor(PredictConfig(
model=Model(LR_SIZE_H, LR_SIZE_W),
session_init=get_model_loader(model_path),
session_init=SmartInit(model_path),
input_names=['Ilr'],
output_names=['prediction']))
......@@ -279,12 +279,12 @@ if __name__ == '__main__':
logger.auto_set_dir()
if args.load:
session_init = SaverRestore(args.load)
session_init = SmartInit(args.load)
else:
assert os.path.isfile(args.vgg19)
param_dict = dict(np.load(args.vgg19))
param_dict = {'VGG19/' + name: value for name, value in six.iteritems(param_dict)}
session_init = DictRestore(param_dict)
session_init = SmartInit(param_dict)
nr_tower = max(get_num_gpu(), 1)
data = QueueInput(get_data(args.data))
......
......@@ -143,8 +143,7 @@ if __name__ == '__main__':
with tf.Graph().as_default():
logger.set_logger_dir(os.path.join('train_log', 'cifar' + str(args.classnum)))
config = get_config(args.classnum)
if args.load:
config.session_init = SaverRestore(args.load)
config.session_init = SmartInit(args.load)
num_gpu = get_num_gpu()
trainer = SimpleTrainer() if num_gpu <= 1 \
......
......@@ -106,7 +106,7 @@ class InferenceOnlyModel(Model):
def export_serving(model_path):
"""Export trained model to use it in TensorFlow Serving or cloudML. """
pred_config = PredictConfig(
session_init=get_model_loader(model_path),
session_init=SmartInit(model_path),
model=InferenceOnlyModel(),
input_names=['input_img_bytes'],
output_names=['prediction_img_bytes'])
......@@ -117,7 +117,7 @@ def export_compact(model_path):
"""Export trained model to use it as a frozen and pruned inference graph in
mobile applications. """
pred_config = PredictConfig(
session_init=get_model_loader(model_path),
session_init=SmartInit(model_path),
model=Model(),
input_names=['input_img'],
output_names=['prediction_img'])
......@@ -127,7 +127,7 @@ def export_compact(model_path):
def apply(model_path):
"""Run inference from a training model checkpoint. """
pred_config = PredictConfig(
session_init=get_model_loader(model_path),
session_init=SmartInit(model_path),
model=Model(),
input_names=['input_img'],
output_names=['prediction_img'])
......@@ -141,7 +141,7 @@ def apply(model_path):
def apply_inference_graph(model_path):
"""Run inference from a different graph, which receives encoded images buffers. """
pred_config = PredictConfig(
session_init=get_model_loader(model_path),
session_init=SmartInit(model_path),
model=InferenceOnlyModel(),
input_names=['input_img_bytes'],
output_names=['prediction_img_bytes'])
......
......@@ -107,6 +107,6 @@ if __name__ == '__main__':
ScalarStats(['cost', 'accuracy']))
],
max_epoch=350,
session_init=SaverRestore(args.load) if args.load else None
session_init=SmartInit(args.load)
)
launch_train_with_config(config, SimpleTrainer())
......@@ -70,8 +70,6 @@ if __name__ == '__main__':
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
config.session_init = SmartInit(args.load)
launch_train_with_config(config, SimpleTrainer())
......@@ -12,7 +12,7 @@ from .varmanip import SessionUpdate, get_checkpoint_path, get_savename_from_varn
__all__ = ['SessionInit', 'ChainInit',
'SaverRestore', 'SaverRestoreRelaxed', 'DictRestore',
'JustCurrentSession', 'get_model_loader', 'SmartRestore']
'JustCurrentSession', 'get_model_loader', 'SmartInit']
class SessionInit(object):
......@@ -260,7 +260,7 @@ class ChainInit(SessionInit):
i._run_init(sess)
def SmartRestore(obj, ignore_mismatch=False):
def SmartInit(obj, ignore_mismatch=False):
"""
Create a :class:`SessionInit` to be loaded to a session,
automatically from any supported objects, with some smart heuristics.
......@@ -268,9 +268,9 @@ def SmartRestore(obj, ignore_mismatch=False):
+ A TF checkpoint
+ A dict of numpy arrays
+ A npz file
+ An empty string or None
+ A list of supported objects
+ A npz file, to be interpreted as a dict
+ An empty string or None, in which case the sessinit will be a no-op
+ A list of supported objects, to be initialized one by one
Args:
obj: a supported object
......@@ -285,7 +285,7 @@ def SmartRestore(obj, ignore_mismatch=False):
if not obj:
return JustCurrentSession()
if isinstance(obj, list):
return ChainInit([SmartRestore(x, ignore_mismatch=ignore_mismatch) for x in obj])
return ChainInit([SmartInit(x, ignore_mismatch=ignore_mismatch) for x in obj])
if isinstance(obj, six.string_types):
obj = os.path.expanduser(obj)
if obj.endswith(".npy") or obj.endswith(".npz"):
......@@ -301,11 +301,11 @@ def SmartRestore(obj, ignore_mismatch=False):
# A TF checkpoint must be a prefix of an actual file.
return (SaverRestoreRelaxed if ignore_mismatch else SaverRestore)(obj)
else:
raise ValueError("Invalid argument to SmartRestore: " + obj)
raise ValueError("Invalid argument to SmartInit: " + obj)
if isinstance(obj, dict):
return DictRestore(obj, ignore_mismatch=ignore_mismatch)
raise ValueError("Invalid argument to SmartRestore: " + type(obj))
raise ValueError("Invalid argument to SmartInit: " + type(obj))
get_model_loader = SmartRestore
get_model_loader = SmartInit
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment