Commit 300d840a authored by Yuxin Wu's avatar Yuxin Wu

Support npz format

parent ef8d4e49
...@@ -360,6 +360,7 @@ def autodoc_skip_member(app, what, name, obj, skip, options): ...@@ -360,6 +360,7 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
'remap_get_variable', 'remap_get_variable',
'freeze_get_variable', 'freeze_get_variable',
'Triggerable', 'Triggerable',
'dump_chkpt_vars',
'ParamRestore']: 'ParamRestore']:
return True return True
if name in ['get_data', 'size', 'reset_state']: if name in ['get_data', 'size', 'reset_state']:
......
...@@ -26,7 +26,7 @@ To inspect a checkpoint, the easiest tool is `tf.train.NewCheckpointReader`. Ple ...@@ -26,7 +26,7 @@ To inspect a checkpoint, the easiest tool is `tf.train.NewCheckpointReader`. Ple
expects a model path without the extension. expects a model path without the extension.
You can dump a cleaner version of the model (without unnecessary variables), using You can dump a cleaner version of the model (without unnecessary variables), using
`scripts/dump-model-params.py`, as a simple `var-name: value` dict saved in npy format. `scripts/dump-model-params.py`, as a simple `var-name: value` dict saved in npy/npz format.
The script expects a metagraph file which is also saved by `ModelSaver`. The script expects a metagraph file which is also saved by `ModelSaver`.
...@@ -48,10 +48,10 @@ Unmatched variables on both sides will be printed as a warning. ...@@ -48,10 +48,10 @@ Unmatched variables on both sides will be printed as a warning.
1. You can simply use `tf.stop_gradient` in your model code in some situations (e.g. to freeze first several layers). 1. You can simply use `tf.stop_gradient` in your model code in some situations (e.g. to freeze first several layers).
2. [varreplace.freeze_variables](http://tensorpack.readthedocs.io/en/latest/modules/tfutils.html#tensorpack.tfutils.varreplace.freeze_variables) can wrap some variables with `tf.stop_gradient`. 2. [varreplace.freeze_variables](http://tensorpack.readthedocs.io/en/latest/modules/tfutils.html#tensorpack.tfutils.varreplace.freeze_variables) can wrap some variables with `tf.stop_gradient`.
3. [ScaleGradient](http://tensorpack.readthedocs.io/en/latest/modules/tfutils.html#tensorpack.tfutils.gradproc.ScaleGradient) can be used to set the gradients of some variables to 0. 3. [ScaleGradient](http://tensorpack.readthedocs.io/en/latest/modules/tfutils.html#tensorpack.tfutils.gradproc.ScaleGradient) can be used to set the gradients of some variables to 0.
Note that the above methods only prevent variables being updated by SGD. Note that the above methods only prevent variables being updated by SGD.
Some variables may be updated by other means, Some variables may be updated by other means,
e.g., BatchNorm statistics are updated through the `UPDATE_OPS` collection and the [RunUpdateOps](http://tensorpack.readthedocs.io/en/latest/modules/callbacks.html#tensorpack.callbacks.RunUpdateOps) callback. e.g., BatchNorm statistics are updated through the `UPDATE_OPS` collection and the [RunUpdateOps](http://tensorpack.readthedocs.io/en/latest/modules/callbacks.html#tensorpack.callbacks.RunUpdateOps) callback.
...@@ -19,6 +19,8 @@ if __name__ == '__main__': ...@@ -19,6 +19,8 @@ if __name__ == '__main__':
if args.model.endswith('.npy'): if args.model.endswith('.npy'):
params = np.load(args.model).item() params = np.load(args.model).item()
elif args.model.endswith('.npz'):
params = dict(np.load(args.model))
else: else:
params = dump_chkpt_vars(args.model) params = dump_chkpt_vars(args.model)
logger.info("Variables in the model:") logger.info("Variables in the model:")
......
...@@ -9,7 +9,7 @@ import tensorflow as tf ...@@ -9,7 +9,7 @@ import tensorflow as tf
import imp import imp
from tensorpack import TowerContext, logger from tensorpack import TowerContext, logger
from tensorpack.tfutils import sessinit, varmanip from tensorpack.tfutils import sessinit, varmanip, get_model_loader
from tensorpack.graph_builder.input_source import PlaceholderInput from tensorpack.graph_builder.input_source import PlaceholderInput
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -34,17 +34,14 @@ with tf.Graph().as_default() as G: ...@@ -34,17 +34,14 @@ with tf.Graph().as_default() as G:
tf.train.import_meta_graph(args.meta) tf.train.import_meta_graph(args.meta)
# loading... # loading...
if args.model.endswith('.npy'): init = get_model_loader(args.model)
init = sessinit.DictRestore(np.load(args.model).item())
else:
init = sessinit.SaverRestore(args.model)
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
sess.run(tf.global_variables_initializer()) sess.run(tf.global_variables_initializer())
init.init(sess) init.init(sess)
# dump ... # dump ...
with sess.as_default(): with sess.as_default():
if args.output.endswith('npy'): if args.output.endswith('npy') or args.output.endswith('npz'):
varmanip.dump_session_params(args.output) varmanip.dump_session_params(args.output)
else: else:
var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
......
...@@ -256,12 +256,16 @@ def get_model_loader(filename): ...@@ -256,12 +256,16 @@ def get_model_loader(filename):
Get a corresponding model loader by looking at the file name. Get a corresponding model loader by looking at the file name.
Returns: Returns:
SessInit: either a :class:`DictRestore` (if name ends with 'npy') or SessInit: either a :class:`DictRestore` (if name ends with 'npy/npz') or
:class:`SaverRestore` (otherwise). :class:`SaverRestore` (otherwise).
""" """
if filename.endswith('.npy'): if filename.endswith('.npy'):
assert os.path.isfile(filename), filename assert os.path.isfile(filename), filename
return DictRestore(np.load(filename, encoding='latin1').item()) return DictRestore(np.load(filename, encoding='latin1').item())
elif filename.endswith('.npz'):
assert os.path.isfile(filename), filename
obj = np.load(filename)
return DictRestore(dict(obj))
else: else:
return SaverRestore(filename) return SaverRestore(filename)
......
...@@ -12,6 +12,7 @@ from ..utils import logger ...@@ -12,6 +12,7 @@ from ..utils import logger
from .common import get_op_tensor_name from .common import get_op_tensor_name
__all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars', __all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars',
'load_chkpt_vars',
# 'get_savename_from_varname', 'is_training_name', # 'get_savename_from_varname', 'is_training_name',
'get_checkpoint_path'] 'get_checkpoint_path']
...@@ -112,10 +113,10 @@ class SessionUpdate(object): ...@@ -112,10 +113,10 @@ class SessionUpdate(object):
def dump_session_params(path): def dump_session_params(path):
""" """
Dump value of all TRAINABLE + MODEL variables to a dict, and save as Dump value of all TRAINABLE + MODEL variables to a dict, and save as
npy format (loadable by :class:`DictRestore`). npy/npz format (loadable by :class:`DictRestore`).
Args: Args:
path(str): the path to save the parameters. path(str): the file name to save the parameters. Must ends with npy or npz.
""" """
var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
var.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES)) var.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
...@@ -127,7 +128,12 @@ def dump_session_params(path): ...@@ -127,7 +128,12 @@ def dump_session_params(path):
logger.info("Variables to save to {}:".format(path)) logger.info("Variables to save to {}:".format(path))
keys = sorted(list(result.keys())) keys = sorted(list(result.keys()))
logger.info(pprint.pformat(keys)) logger.info(pprint.pformat(keys))
np.save(path, result) if path.endswith('.npy'):
np.save(path, result)
elif path.endswith('.npz'):
np.savez_compressed(path, **result)
else:
raise ValueError("Don't know which format to use for {}".format(path))
def get_checkpoint_path(model_path): def get_checkpoint_path(model_path):
...@@ -160,7 +166,7 @@ def get_checkpoint_path(model_path): ...@@ -160,7 +166,7 @@ def get_checkpoint_path(model_path):
return model_path return model_path
def dump_chkpt_vars(model_path): def load_chkpt_vars(model_path):
""" Dump all variables from a checkpoint to a dict. """ Dump all variables from a checkpoint to a dict.
Args: Args:
...@@ -177,6 +183,9 @@ def dump_chkpt_vars(model_path): ...@@ -177,6 +183,9 @@ def dump_chkpt_vars(model_path):
result[n] = reader.get_tensor(n) result[n] = reader.get_tensor(n)
return result return result
def dump_chkpt_vars(model_path):
logger.warn("dump_chkpt_vars was renamed to load_chkpt_vars!")
return load_chkpt_vars(model_path)
def is_training_name(name): def is_training_name(name):
""" """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment