Commit 300d840a authored by Yuxin Wu

Support npz format

parent ef8d4e49
@@ -360,6 +360,7 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
'remap_get_variable',
'freeze_get_variable',
'Triggerable',
+'dump_chkpt_vars',
'ParamRestore']:
return True
if name in ['get_data', 'size', 'reset_state']:
@@ -26,7 +26,7 @@ To inspect a checkpoint, the easiest tool is `tf.train.NewCheckpointReader`.
expects a model path without the extension.
You can dump a cleaner version of the model (without unnecessary variables), using
-`scripts/dump-model-params.py`, as a simple `var-name: value` dict saved in npy format.
+`scripts/dump-model-params.py`, as a simple `var-name: value` dict saved in npy/npz format.
The script expects a metagraph file which is also saved by `ModelSaver`.
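For illustration, a minimal sketch of reading such a dump back with plain NumPy (the file name and helper are hypothetical; an `.npy` dump is one pickled dict, an `.npz` dump is an archive with one entry per variable):

```python
import numpy as np

def load_param_dict(path):
    # .npy stores the whole dict as a single pickled object,
    # hence .item() to unwrap the 0-d object array.
    if path.endswith('.npy'):
        return np.load(path, encoding='latin1').item()
    # .npz is an archive of named arrays; dict() materializes it.
    if path.endswith('.npz'):
        return dict(np.load(path))
    raise ValueError("expected .npy or .npz: " + path)

params = load_param_dict('model.npz')  # hypothetical dump file
for name in sorted(params):
    print(name, params[name].shape)
```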
@@ -48,10 +48,10 @@ Unmatched variables on both sides will be printed as a warning.
1. You can simply use `tf.stop_gradient` in your model code in some situations (e.g. to freeze first several layers).
2. [varreplace.freeze_variables](http://tensorpack.readthedocs.io/en/latest/modules/tfutils.html#tensorpack.tfutils.varreplace.freeze_variables) can wrap some variables with `tf.stop_gradient`.
3. [ScaleGradient](http://tensorpack.readthedocs.io/en/latest/modules/tfutils.html#tensorpack.tfutils.gradproc.ScaleGradient) can be used to set the gradients of some variables to 0.
Note that the above methods only prevent variables from being updated by SGD.
Some variables may be updated by other means,
e.g., BatchNorm statistics are updated through the `UPDATE_OPS` collection and the [RunUpdateOps](http://tensorpack.readthedocs.io/en/latest/modules/callbacks.html#tensorpack.callbacks.RunUpdateOps) callback.
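As an illustration of option 1 in the list above, a minimal sketch (made-up layer names and shapes, assuming the TF 1.x `tf.layers` API):

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 28, 28, 1])
# Backprop stops at the barrier, so 'conv1' receives no gradient
# and is effectively frozen; 'fc' still trains normally.
feat = tf.layers.conv2d(x, 32, 3, name='conv1')
feat = tf.stop_gradient(feat)
logits = tf.layers.dense(tf.layers.flatten(feat), 10, name='fc')
```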
@@ -19,6 +19,8 @@ if __name__ == '__main__':
if args.model.endswith('.npy'):
params = np.load(args.model).item()
+elif args.model.endswith('.npz'):
+params = dict(np.load(args.model))
else:
params = dump_chkpt_vars(args.model)
logger.info("Variables in the model:")
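The asymmetry between the two branches above comes from NumPy itself; a self-contained sketch (assumes a NumPy version that still allows pickled `.npy` loading by default):

```python
import numpy as np

d = {'conv1/W': np.zeros((3, 3, 1, 32), dtype='float32')}
np.save('params.npy', d)     # pickles the dict as one object array
np.savez('params.npz', **d)  # one archive entry per key

npy = np.load('params.npy').item()  # unwrap the pickled dict
npz = dict(np.load('params.npz'))   # NpzFile acts like a mapping
assert sorted(npy) == sorted(npz)
```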
@@ -9,7 +9,7 @@ import tensorflow as tf
import imp
from tensorpack import TowerContext, logger
-from tensorpack.tfutils import sessinit, varmanip
+from tensorpack.tfutils import sessinit, varmanip, get_model_loader
from tensorpack.graph_builder.input_source import PlaceholderInput
parser = argparse.ArgumentParser()
@@ -34,17 +34,14 @@ with tf.Graph().as_default() as G:
tf.train.import_meta_graph(args.meta)
# loading...
-if args.model.endswith('.npy'):
-init = sessinit.DictRestore(np.load(args.model).item())
-else:
-init = sessinit.SaverRestore(args.model)
+init = get_model_loader(args.model)
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
sess.run(tf.global_variables_initializer())
init.init(sess)
# dump ...
with sess.as_default():
-if args.output.endswith('npy'):
+if args.output.endswith('npy') or args.output.endswith('npz'):
varmanip.dump_session_params(args.output)
else:
var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
@@ -256,12 +256,16 @@ def get_model_loader(filename):
Get a corresponding model loader by looking at the file name.
Returns:
-SessInit: either a :class:`DictRestore` (if name ends with 'npy') or
+SessInit: either a :class:`DictRestore` (if name ends with 'npy/npz') or
:class:`SaverRestore` (otherwise).
"""
if filename.endswith('.npy'):
assert os.path.isfile(filename), filename
return DictRestore(np.load(filename, encoding='latin1').item())
+elif filename.endswith('.npz'):
+assert os.path.isfile(filename), filename
+obj = np.load(filename)
+return DictRestore(dict(obj))
else:
return SaverRestore(filename)
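A usage sketch (hypothetical model path; the import matches the one added to `scripts/dump-model-params.py` above, and assumes the graph has already been built):

```python
import tensorflow as tf
from tensorpack.tfutils import get_model_loader

# The extension picks the loader: .npy/.npz -> DictRestore,
# anything else is treated as a TF checkpoint -> SaverRestore.
init = get_model_loader('train_log/model.npz')
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    init.init(sess)  # restores matching variables into the session
```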
@@ -12,6 +12,7 @@ from ..utils import logger
from .common import get_op_tensor_name
__all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars',
+'load_chkpt_vars',
# 'get_savename_from_varname', 'is_training_name',
'get_checkpoint_path']
@@ -112,10 +113,10 @@ class SessionUpdate(object):
def dump_session_params(path):
"""
Dump value of all TRAINABLE + MODEL variables to a dict, and save as
-npy format (loadable by :class:`DictRestore`).
+npy/npz format (loadable by :class:`DictRestore`).
Args:
-path(str): the path to save the parameters.
+path(str): the file name to save the parameters. Must end with npy or npz.
"""
var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
var.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
@@ -127,7 +128,12 @@ def dump_session_params(path):
logger.info("Variables to save to {}:".format(path))
keys = sorted(list(result.keys()))
logger.info(pprint.pformat(keys))
-np.save(path, result)
+if path.endswith('.npy'):
+np.save(path, result)
+elif path.endswith('.npz'):
+np.savez_compressed(path, **result)
+else:
+raise ValueError("Don't know which format to use for {}".format(path))
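A usage sketch of the new dispatch (file names hypothetical; the call runs under `sess.as_default()`, matching how `scripts/dump-model-params.py` invokes it above):

```python
import tensorflow as tf
from tensorpack.tfutils.varmanip import dump_session_params

# Assumes a graph holding the trained variables has been built.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    with sess.as_default():
        dump_session_params('params.npz')  # compressed npz archive
        dump_session_params('params.npy')  # pickled dict
        # any other extension raises the ValueError above
```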
def get_checkpoint_path(model_path):
@@ -160,7 +166,7 @@ def get_checkpoint_path(model_path):
return model_path
-def dump_chkpt_vars(model_path):
+def load_chkpt_vars(model_path):
""" Dump all variables from a checkpoint to a dict.
Args:
@@ -177,6 +183,9 @@ def dump_chkpt_vars(model_path):
result[n] = reader.get_tensor(n)
return result
+def dump_chkpt_vars(model_path):
+logger.warn("dump_chkpt_vars was renamed to load_chkpt_vars!")
+return load_chkpt_vars(model_path)
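Both names keep working after the rename; the old one only logs a warning before delegating (checkpoint path hypothetical):

```python
from tensorpack.tfutils.varmanip import load_chkpt_vars, dump_chkpt_vars

new = load_chkpt_vars('train_log/model-10000')
old = dump_chkpt_vars('train_log/model-10000')  # warns, then delegates
assert sorted(new) == sorted(old)
```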
def is_training_name(name):
"""