Commit 5beab907 authored by Yuxin Wu

[breaking] rename ParamRestore to DictRestore

parent 6ac34dfb
......@@ -12,6 +12,7 @@ for more details.
If you think:
1. The framework has limitation so your XYZ cannot be supported, OR
2. Your XYZ is very common, or very well-defined, so it would be nice to include it.
Then it's a good time to open an issue.
## How to dump/inspect a model
......@@ -25,7 +26,7 @@ expects a path without the extension.
You can dump a cleaner version of the model (with only model/trainable variables), with
`scripts/dump-model-params.py`, as a simple `var-name: value` dict saved in npy format.
It expects a metagraph file which is also saved by `ModelSaver`.
The script expects a metagraph file which is also saved by `ModelSaver`.
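For reference (not part of this commit), the resulting npy file is a plain pickled dict and can be inspected directly; a minimal sketch, where `model.npy` is a placeholder for the output of `dump-model-params.py`:

```python
import numpy as np

# Load the dumped var-name -> ndarray dict; encoding='latin1' matches how
# the examples in this repo read such dumps.
params = np.load('model.npy', encoding='latin1').item()
for name, value in sorted(params.items()):
    print(name, value.shape, value.dtype)
```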
## How to load a model / do transfer learning
......
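As a hedged illustration of the transfer-learning pattern this section refers to (not part of this commit): `DictRestore` only assigns variables whose names intersect the given dict (see the `intersect` logic in `sessinit.py` below), so you can filter a pretrained dump down to the layers you want to reuse. The path and the `'conv'` prefix are purely hypothetical:

```python
import numpy as np
from tensorpack.tfutils.sessinit import DictRestore

# 'pretrained.npy' is a placeholder path to a dump produced as described above.
params = np.load('pretrained.npy', encoding='latin1').item()
# Keep only the variables you want to transfer; names missing from the new
# graph are skipped by DictRestore, and vice versa.
transfer = {name: value for name, value in params.items() if name.startswith('conv')}
init = DictRestore(transfer)   # pass as session_init to a train/predict config
```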
......@@ -108,7 +108,7 @@ def run_test(model_path, img_file):
param_dict = np.load(model_path, encoding='latin1').item()
predict_func = OfflinePredictor(PredictConfig(
model=Model(),
session_init=ParamRestore(param_dict),
session_init=DictRestore(param_dict),
input_names=['input'],
output_names=['resized_map']
))
......
......@@ -308,7 +308,7 @@ if __name__ == '__main__':
if args.run:
assert args.load.endswith('.npy')
run_image(Model(), ParamRestore(np.load(args.load, encoding='latin1').item()), args.run)
run_image(Model(), DictRestore(np.load(args.load, encoding='latin1').item()), args.run)
sys.exit()
assert args.gpu is not None, "Need to specify a list of gpu for training!"
......
......@@ -190,5 +190,5 @@ if __name__ == '__main__':
eval_on_ILSVRC12(args.load, args.data)
elif args.run:
assert args.load.endswith('.npy')
run_image(Model(), ParamRestore(
run_image(Model(), DictRestore(
np.load(args.load, encoding='latin1').item()), args.run)
......@@ -114,7 +114,7 @@ def get_inference_augmentor():
def run_test(params, input):
pred_config = PredictConfig(
model=Model(),
session_init=ParamRestore(params),
session_init=DictRestore(params),
input_names=['input'],
output_names=['prob']
)
......@@ -139,7 +139,7 @@ def eval_on_ILSVRC12(params, data_dir):
ds = BatchData(ds, 128, remainder=True)
pred_config = PredictConfig(
model=Model(),
session_init=ParamRestore(params),
session_init=DictRestore(params),
input_names=['input', 'label'],
output_names=['wrong-top1', 'wrong-top5']
)
......
......@@ -56,7 +56,7 @@ def run_test(path, input):
param_dict = np.load(path, encoding='latin1').item()
predictor = OfflinePredictor(PredictConfig(
model=Model(),
session_init=ParamRestore(param_dict),
session_init=DictRestore(param_dict),
input_names=['input'],
output_names=['prob']
))
......
......@@ -66,7 +66,7 @@ def run_test(path, input):
param_dict = np.load(path, encoding='latin1').item()
predict_func = OfflinePredictor(PredictConfig(
model=Model(),
session_init=ParamRestore(param_dict),
session_init=DictRestore(param_dict),
input_names=['input'],
output_names=['prob'] # prob:0 is the probability distribution
))
......
......@@ -32,7 +32,7 @@ with tf.Graph().as_default() as G:
# loading...
if args.model.endswith('.npy'):
init = sessinit.ParamRestore(np.load(args.model).item())
init = sessinit.DictRestore(np.load(args.model).item())
else:
init = sessinit.SaverRestore(args.model)
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
......
......@@ -123,7 +123,10 @@ class BatchData(ProxyDataFlow):
elif type(dt) == float:
tp = 'float32'
else:
tp = dt.dtype
try:
tp = dt.dtype
except AttributeError:
raise TypeError("Unsupported type to batch: {}".format(type(dt)))
try:
result.append(
np.asarray([x[k] for x in data_holder], dtype=tp))
......
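A rough standalone sketch (not from this diff) of what the dtype selection above does when batching one component: Python floats map to `'float32'`, anything else must expose a numpy `dtype`, and the per-example values are stacked with `np.asarray`. The integer branch is elided in this hunk, so it is omitted here as well:

```python
import numpy as np

def batch_component(values):
    # Mirrors only the logic shown in the hunk above.
    dt = values[0]
    if type(dt) == float:
        tp = 'float32'
    else:
        try:
            tp = dt.dtype              # e.g. numpy arrays / numpy scalars
        except AttributeError:
            raise TypeError("Unsupported type to batch: {}".format(type(dt)))
    return np.asarray(values, dtype=tp)

print(batch_component([1.0, 2.0, 3.5]).dtype)             # float32
print(batch_component([np.zeros(2), np.ones(2)]).shape)   # (2, 2)
```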
......@@ -4,8 +4,8 @@
import sys
import os
import cv2
import multiprocessing as mp
import cv2
from six.moves import range
from .base import DataFlow
......
......@@ -8,12 +8,13 @@ import tensorflow as tf
import six
from ..utils import logger
from ..utils.develop import deprecated
from .common import get_op_tensor_name
from .varmanip import (SessionUpdate, get_savename_from_varname,
is_training_name, get_checkpoint_path)
__all__ = ['SessionInit', 'SaverRestore', 'SaverRestoreRelaxed',
'ParamRestore', 'ChainInit',
'ParamRestore', 'DictRestore', 'ChainInit',
'JustCurrentSession', 'get_model_loader']
......@@ -156,7 +157,7 @@ class SaverRestoreRelaxed(SaverRestore):
self._match_vars(f)
class ParamRestore(SessionInit):
class DictRestore(SessionInit):
"""
Restore variables from a dictionary.
"""
......@@ -190,6 +191,11 @@ class ParamRestore(SessionInit):
upd.update({name: value for name, value in six.iteritems(self.prms) if name in intersect})
@deprecated("Use `DictRestore` instead!", "2017-06-01")
def ParamRestore(d):
return DictRestore(d)
class ChainInit(SessionInit):
""" Initialize a session by a list of :class:`SessionInit` instance, executed one by one.
This can be useful for, e.g., loading several models from different files
......@@ -221,11 +227,11 @@ def get_model_loader(filename):
Get a corresponding model loader by looking at the file name.
Returns:
SessInit: either a :class:`ParamRestore` (if name ends with 'npy') or
SessInit: either a :class:`DictRestore` (if name ends with 'npy') or
:class:`SaverRestore` (otherwise).
"""
if filename.endswith('.npy'):
assert os.path.isfile(filename), filename
return ParamRestore(np.load(filename, encoding='latin1').item())
return DictRestore(np.load(filename, encoding='latin1').item())
else:
return SaverRestore(filename)
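A hedged usage note for callers of this module: after the rename, existing code can keep calling `ParamRestore` (it now warns and forwards to `DictRestore`), but new code should use `DictRestore` directly, or let `get_model_loader` pick the loader from the file extension. A minimal sketch with placeholder paths (the npy file must exist for the assert in `get_model_loader` to pass):

```python
import numpy as np
from tensorpack.tfutils.sessinit import DictRestore, get_model_loader

init = get_model_loader('model.npy')      # npy dump -> DictRestore(np.load(...).item())
init = get_model_loader('model-100000')   # anything else -> SaverRestore (TF checkpoint)

# Equivalent explicit form for npy dumps:
init = DictRestore(np.load('model.npy', encoding='latin1').item())
```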
......@@ -119,7 +119,7 @@ class SessionUpdate(object):
def dump_session_params(path):
"""
Dump value of all TRAINABLE + MODEL variables to a dict, and save as
npy format (loadable by :class:`ParamRestore`).
npy format (loadable by :class:`DictRestore`).
Args:
path(str): the path to save the parameters.
......
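A short sketch (not from this diff; assumes an active TF session whose model variables are initialized, and placeholder paths) of the round trip this docstring describes: dump the parameters, then load them back with `DictRestore`:

```python
import numpy as np
from tensorpack.tfutils.varmanip import dump_session_params
from tensorpack.tfutils.sessinit import DictRestore

# Inside an active session: writes TRAINABLE + MODEL variables as an npy dict.
dump_session_params('params.npy')

# Later, reload the dump for inference or fine-tuning.
init = DictRestore(np.load('params.npy', encoding='latin1').item())
```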
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: argtools.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import inspect
......