Commit 300d840a authored by Yuxin Wu's avatar Yuxin Wu

Support npz format

parent ef8d4e49
......@@ -360,6 +360,7 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
'remap_get_variable',
'freeze_get_variable',
'Triggerable',
'dump_chkpt_vars',
'ParamRestore']:
return True
if name in ['get_data', 'size', 'reset_state']:
......
......@@ -26,7 +26,7 @@ To inspect a checkpoint, the easiest tool is `tf.train.NewCheckpointReader`. Ple
expects a model path without the extension.
You can dump a cleaner version of the model (without unnecessary variables), using
`scripts/dump-model-params.py`, as a simple `var-name: value` dict saved in npy format.
`scripts/dump-model-params.py`, as a simple `var-name: value` dict saved in npy/npz format.
The script expects a metagraph file which is also saved by `ModelSaver`.
......
......@@ -19,6 +19,8 @@ if __name__ == '__main__':
if args.model.endswith('.npy'):
params = np.load(args.model).item()
elif args.model.endswith('.npz'):
params = dict(np.load(args.model))
else:
params = dump_chkpt_vars(args.model)
logger.info("Variables in the model:")
......
......@@ -9,7 +9,7 @@ import tensorflow as tf
import imp
from tensorpack import TowerContext, logger
from tensorpack.tfutils import sessinit, varmanip
from tensorpack.tfutils import sessinit, varmanip, get_model_loader
from tensorpack.graph_builder.input_source import PlaceholderInput
parser = argparse.ArgumentParser()
......@@ -34,17 +34,14 @@ with tf.Graph().as_default() as G:
tf.train.import_meta_graph(args.meta)
# loading...
if args.model.endswith('.npy'):
init = sessinit.DictRestore(np.load(args.model).item())
else:
init = sessinit.SaverRestore(args.model)
init = get_model_loader(args.model)
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
sess.run(tf.global_variables_initializer())
init.init(sess)
# dump ...
with sess.as_default():
if args.output.endswith('npy'):
if args.output.endswith('npy') or args.output.endswith('npz'):
varmanip.dump_session_params(args.output)
else:
var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
......
......@@ -256,12 +256,16 @@ def get_model_loader(filename):
Get a corresponding model loader by looking at the file name.
Returns:
SessInit: either a :class:`DictRestore` (if name ends with 'npy') or
SessInit: either a :class:`DictRestore` (if name ends with 'npy/npz') or
:class:`SaverRestore` (otherwise).
"""
if filename.endswith('.npy'):
assert os.path.isfile(filename), filename
return DictRestore(np.load(filename, encoding='latin1').item())
elif filename.endswith('.npz'):
assert os.path.isfile(filename), filename
obj = np.load(filename)
return DictRestore(dict(obj))
else:
return SaverRestore(filename)
......
......@@ -12,6 +12,7 @@ from ..utils import logger
from .common import get_op_tensor_name
__all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars',
'load_chkpt_vars',
# 'get_savename_from_varname', 'is_training_name',
'get_checkpoint_path']
......@@ -112,10 +113,10 @@ class SessionUpdate(object):
def dump_session_params(path):
"""
Dump value of all TRAINABLE + MODEL variables to a dict, and save as
npy format (loadable by :class:`DictRestore`).
npy/npz format (loadable by :class:`DictRestore`).
Args:
path(str): the path to save the parameters.
path(str): the file name to save the parameters. Must end with npy or npz.
"""
var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
var.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
......@@ -127,7 +128,12 @@ def dump_session_params(path):
logger.info("Variables to save to {}:".format(path))
keys = sorted(list(result.keys()))
logger.info(pprint.pformat(keys))
if path.endswith('.npy'):
np.save(path, result)
elif path.endswith('.npz'):
np.savez_compressed(path, **result)
else:
raise ValueError("Don't know which format to use for {}".format(path))
def get_checkpoint_path(model_path):
......@@ -160,7 +166,7 @@ def get_checkpoint_path(model_path):
return model_path
def dump_chkpt_vars(model_path):
def load_chkpt_vars(model_path):
""" Load all variables from a checkpoint to a dict.
Args:
......@@ -177,6 +183,9 @@ def dump_chkpt_vars(model_path):
result[n] = reader.get_tensor(n)
return result
def dump_chkpt_vars(model_path):
    """Deprecated alias of ``load_chkpt_vars``, kept for backward compatibility.

    Emits a rename warning, then delegates unchanged to ``load_chkpt_vars``.

    Args:
        model_path: path to a checkpoint, forwarded as-is to ``load_chkpt_vars``.

    Returns:
        Whatever ``load_chkpt_vars(model_path)`` returns (a name-to-value dict).
    """
    logger.warn("dump_chkpt_vars was renamed to load_chkpt_vars!")
    return load_chkpt_vars(model_path)
def is_training_name(name):
"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment