Commit 5beab907 authored by Yuxin Wu's avatar Yuxin Wu

[breaking] rename ParamRestore to DictRestore

parent 6ac34dfb
...@@ -12,6 +12,7 @@ for more details. ...@@ -12,6 +12,7 @@ for more details.
It you think: It you think:
1. The framework has limitation so your XYZ cannot be supported, OR 1. The framework has limitation so your XYZ cannot be supported, OR
2. Your XYZ is very common, or very well-defined, so it would be nice to include it. 2. Your XYZ is very common, or very well-defined, so it would be nice to include it.
Then it's a good time to open an issue. Then it's a good time to open an issue.
## How to dump/inspect a model ## How to dump/inspect a model
...@@ -25,7 +26,7 @@ expects a path without the extension. ...@@ -25,7 +26,7 @@ expects a path without the extension.
You can dump a cleaner version of the model (with only model/trainable variables), with You can dump a cleaner version of the model (with only model/trainable variables), with
`scripts/dump-model-params.py`, as a simple `var-name: value` dict saved in npy format. `scripts/dump-model-params.py`, as a simple `var-name: value` dict saved in npy format.
It expects a metagraph file which is also saved by `ModelSaver`. The script expects a metagraph file which is also saved by `ModelSaver`.
## How to load a model / do transfer learning ## How to load a model / do transfer learning
......
...@@ -108,7 +108,7 @@ def run_test(model_path, img_file): ...@@ -108,7 +108,7 @@ def run_test(model_path, img_file):
param_dict = np.load(model_path, encoding='latin1').item() param_dict = np.load(model_path, encoding='latin1').item()
predict_func = OfflinePredictor(PredictConfig( predict_func = OfflinePredictor(PredictConfig(
model=Model(), model=Model(),
session_init=ParamRestore(param_dict), session_init=DictRestore(param_dict),
input_names=['input'], input_names=['input'],
output_names=['resized_map'] output_names=['resized_map']
)) ))
......
...@@ -308,7 +308,7 @@ if __name__ == '__main__': ...@@ -308,7 +308,7 @@ if __name__ == '__main__':
if args.run: if args.run:
assert args.load.endswith('.npy') assert args.load.endswith('.npy')
run_image(Model(), ParamRestore(np.load(args.load, encoding='latin1').item()), args.run) run_image(Model(), DictRestore(np.load(args.load, encoding='latin1').item()), args.run)
sys.exit() sys.exit()
assert args.gpu is not None, "Need to specify a list of gpu for training!" assert args.gpu is not None, "Need to specify a list of gpu for training!"
......
...@@ -190,5 +190,5 @@ if __name__ == '__main__': ...@@ -190,5 +190,5 @@ if __name__ == '__main__':
eval_on_ILSVRC12(args.load, args.data) eval_on_ILSVRC12(args.load, args.data)
elif args.run: elif args.run:
assert args.load.endswith('.npy') assert args.load.endswith('.npy')
run_image(Model(), ParamRestore( run_image(Model(), DictRestore(
np.load(args.load, encoding='latin1').item()), args.run) np.load(args.load, encoding='latin1').item()), args.run)
...@@ -114,7 +114,7 @@ def get_inference_augmentor(): ...@@ -114,7 +114,7 @@ def get_inference_augmentor():
def run_test(params, input): def run_test(params, input):
pred_config = PredictConfig( pred_config = PredictConfig(
model=Model(), model=Model(),
session_init=ParamRestore(params), session_init=DictRestore(params),
input_names=['input'], input_names=['input'],
output_names=['prob'] output_names=['prob']
) )
...@@ -139,7 +139,7 @@ def eval_on_ILSVRC12(params, data_dir): ...@@ -139,7 +139,7 @@ def eval_on_ILSVRC12(params, data_dir):
ds = BatchData(ds, 128, remainder=True) ds = BatchData(ds, 128, remainder=True)
pred_config = PredictConfig( pred_config = PredictConfig(
model=Model(), model=Model(),
session_init=ParamRestore(params), session_init=DictRestore(params),
input_names=['input', 'label'], input_names=['input', 'label'],
output_names=['wrong-top1', 'wrong-top5'] output_names=['wrong-top1', 'wrong-top5']
) )
......
...@@ -56,7 +56,7 @@ def run_test(path, input): ...@@ -56,7 +56,7 @@ def run_test(path, input):
param_dict = np.load(path, encoding='latin1').item() param_dict = np.load(path, encoding='latin1').item()
predictor = OfflinePredictor(PredictConfig( predictor = OfflinePredictor(PredictConfig(
model=Model(), model=Model(),
session_init=ParamRestore(param_dict), session_init=DictRestore(param_dict),
input_names=['input'], input_names=['input'],
output_names=['prob'] output_names=['prob']
)) ))
......
...@@ -66,7 +66,7 @@ def run_test(path, input): ...@@ -66,7 +66,7 @@ def run_test(path, input):
param_dict = np.load(path, encoding='latin1').item() param_dict = np.load(path, encoding='latin1').item()
predict_func = OfflinePredictor(PredictConfig( predict_func = OfflinePredictor(PredictConfig(
model=Model(), model=Model(),
session_init=ParamRestore(param_dict), session_init=DictRestore(param_dict),
input_names=['input'], input_names=['input'],
output_names=['prob'] # prob:0 is the probability distribution output_names=['prob'] # prob:0 is the probability distribution
)) ))
......
...@@ -32,7 +32,7 @@ with tf.Graph().as_default() as G: ...@@ -32,7 +32,7 @@ with tf.Graph().as_default() as G:
# loading... # loading...
if args.model.endswith('.npy'): if args.model.endswith('.npy'):
init = sessinit.ParamRestore(np.load(args.model).item()) init = sessinit.DictRestore(np.load(args.model).item())
else: else:
init = sessinit.SaverRestore(args.model) init = sessinit.SaverRestore(args.model)
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
......
...@@ -123,7 +123,10 @@ class BatchData(ProxyDataFlow): ...@@ -123,7 +123,10 @@ class BatchData(ProxyDataFlow):
elif type(dt) == float: elif type(dt) == float:
tp = 'float32' tp = 'float32'
else: else:
try:
tp = dt.dtype tp = dt.dtype
except:
raise TypeError("Unsupported type to batch: {}".format(type(dt)))
try: try:
result.append( result.append(
np.asarray([x[k] for x in data_holder], dtype=tp)) np.asarray([x[k] for x in data_holder], dtype=tp))
......
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
import sys import sys
import os import os
import cv2
import multiprocessing as mp import multiprocessing as mp
import cv2
from six.moves import range from six.moves import range
from .base import DataFlow from .base import DataFlow
......
...@@ -8,12 +8,13 @@ import tensorflow as tf ...@@ -8,12 +8,13 @@ import tensorflow as tf
import six import six
from ..utils import logger from ..utils import logger
from ..utils.develop import deprecated
from .common import get_op_tensor_name from .common import get_op_tensor_name
from .varmanip import (SessionUpdate, get_savename_from_varname, from .varmanip import (SessionUpdate, get_savename_from_varname,
is_training_name, get_checkpoint_path) is_training_name, get_checkpoint_path)
__all__ = ['SessionInit', 'SaverRestore', 'SaverRestoreRelaxed', __all__ = ['SessionInit', 'SaverRestore', 'SaverRestoreRelaxed',
'ParamRestore', 'ChainInit', 'ParamRestore', 'DictRestore', 'ChainInit',
'JustCurrentSession', 'get_model_loader'] 'JustCurrentSession', 'get_model_loader']
...@@ -156,7 +157,7 @@ class SaverRestoreRelaxed(SaverRestore): ...@@ -156,7 +157,7 @@ class SaverRestoreRelaxed(SaverRestore):
self._match_vars(f) self._match_vars(f)
class ParamRestore(SessionInit): class DictRestore(SessionInit):
""" """
Restore variables from a dictionary. Restore variables from a dictionary.
""" """
...@@ -190,6 +191,11 @@ class ParamRestore(SessionInit): ...@@ -190,6 +191,11 @@ class ParamRestore(SessionInit):
upd.update({name: value for name, value in six.iteritems(self.prms) if name in intersect}) upd.update({name: value for name, value in six.iteritems(self.prms) if name in intersect})
@deprecated("Use `DictRestore` instead!", "2017-06-01")
def ParamRestore(d):
return DictRestore(d)
class ChainInit(SessionInit): class ChainInit(SessionInit):
""" Initialize a session by a list of :class:`SessionInit` instance, executed one by one. """ Initialize a session by a list of :class:`SessionInit` instance, executed one by one.
This can be useful for, e.g., loading several models from different files This can be useful for, e.g., loading several models from different files
...@@ -221,11 +227,11 @@ def get_model_loader(filename): ...@@ -221,11 +227,11 @@ def get_model_loader(filename):
Get a corresponding model loader by looking at the file name. Get a corresponding model loader by looking at the file name.
Returns: Returns:
SessInit: either a :class:`ParamRestore` (if name ends with 'npy') or SessInit: either a :class:`DictRestore` (if name ends with 'npy') or
:class:`SaverRestore` (otherwise). :class:`SaverRestore` (otherwise).
""" """
if filename.endswith('.npy'): if filename.endswith('.npy'):
assert os.path.isfile(filename), filename assert os.path.isfile(filename), filename
return ParamRestore(np.load(filename, encoding='latin1').item()) return DictRestore(np.load(filename, encoding='latin1').item())
else: else:
return SaverRestore(filename) return SaverRestore(filename)
...@@ -119,7 +119,7 @@ class SessionUpdate(object): ...@@ -119,7 +119,7 @@ class SessionUpdate(object):
def dump_session_params(path): def dump_session_params(path):
""" """
Dump value of all TRAINABLE + MODEL variables to a dict, and save as Dump value of all TRAINABLE + MODEL variables to a dict, and save as
npy format (loadable by :class:`ParamRestore`). npy format (loadable by :class:`DictRestore`).
Args: Args:
path(str): the path to save the parameters. path(str): the path to save the parameters.
......
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: argtools.py # File: argtools.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import inspect import inspect
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment