Commit 300d840a authored by Yuxin Wu's avatar Yuxin Wu

Support npz format

parent ef8d4e49
...@@ -360,6 +360,7 @@ def autodoc_skip_member(app, what, name, obj, skip, options): ...@@ -360,6 +360,7 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
'remap_get_variable', 'remap_get_variable',
'freeze_get_variable', 'freeze_get_variable',
'Triggerable', 'Triggerable',
'dump_chkpt_vars',
'ParamRestore']: 'ParamRestore']:
return True return True
if name in ['get_data', 'size', 'reset_state']: if name in ['get_data', 'size', 'reset_state']:
......
...@@ -26,7 +26,7 @@ To inspect a checkpoint, the easiest tool is `tf.train.NewCheckpointReader`. Ple ...@@ -26,7 +26,7 @@ To inspect a checkpoint, the easiest tool is `tf.train.NewCheckpointReader`. Ple
expects a model path without the extension. expects a model path without the extension.
You can dump a cleaner version of the model (without unnecessary variables), using You can dump a cleaner version of the model (without unnecessary variables), using
`scripts/dump-model-params.py`, as a simple `var-name: value` dict saved in npy format. `scripts/dump-model-params.py`, as a simple `var-name: value` dict saved in npy/npz format.
The script expects a metagraph file which is also saved by `ModelSaver`. The script expects a metagraph file which is also saved by `ModelSaver`.
......
...@@ -19,6 +19,8 @@ if __name__ == '__main__': ...@@ -19,6 +19,8 @@ if __name__ == '__main__':
if args.model.endswith('.npy'): if args.model.endswith('.npy'):
params = np.load(args.model).item() params = np.load(args.model).item()
elif args.model.endswith('.npz'):
params = dict(np.load(args.model))
else: else:
params = dump_chkpt_vars(args.model) params = dump_chkpt_vars(args.model)
logger.info("Variables in the model:") logger.info("Variables in the model:")
......
...@@ -9,7 +9,7 @@ import tensorflow as tf ...@@ -9,7 +9,7 @@ import tensorflow as tf
import imp import imp
from tensorpack import TowerContext, logger from tensorpack import TowerContext, logger
from tensorpack.tfutils import sessinit, varmanip from tensorpack.tfutils import sessinit, varmanip, get_model_loader
from tensorpack.graph_builder.input_source import PlaceholderInput from tensorpack.graph_builder.input_source import PlaceholderInput
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -34,17 +34,14 @@ with tf.Graph().as_default() as G: ...@@ -34,17 +34,14 @@ with tf.Graph().as_default() as G:
tf.train.import_meta_graph(args.meta) tf.train.import_meta_graph(args.meta)
# loading... # loading...
if args.model.endswith('.npy'): init = get_model_loader(args.model)
init = sessinit.DictRestore(np.load(args.model).item())
else:
init = sessinit.SaverRestore(args.model)
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
sess.run(tf.global_variables_initializer()) sess.run(tf.global_variables_initializer())
init.init(sess) init.init(sess)
# dump ... # dump ...
with sess.as_default(): with sess.as_default():
if args.output.endswith('npy'): if args.output.endswith('npy') or args.output.endswith('npz'):
varmanip.dump_session_params(args.output) varmanip.dump_session_params(args.output)
else: else:
var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
......
...@@ -256,12 +256,16 @@ def get_model_loader(filename): ...@@ -256,12 +256,16 @@ def get_model_loader(filename):
Get a corresponding model loader by looking at the file name. Get a corresponding model loader by looking at the file name.
Returns: Returns:
SessInit: either a :class:`DictRestore` (if name ends with 'npy') or SessInit: either a :class:`DictRestore` (if name ends with 'npy/npz') or
:class:`SaverRestore` (otherwise). :class:`SaverRestore` (otherwise).
""" """
if filename.endswith('.npy'): if filename.endswith('.npy'):
assert os.path.isfile(filename), filename assert os.path.isfile(filename), filename
return DictRestore(np.load(filename, encoding='latin1').item()) return DictRestore(np.load(filename, encoding='latin1').item())
elif filename.endswith('.npz'):
assert os.path.isfile(filename), filename
obj = np.load(filename)
return DictRestore(dict(obj))
else: else:
return SaverRestore(filename) return SaverRestore(filename)
......
...@@ -12,6 +12,7 @@ from ..utils import logger ...@@ -12,6 +12,7 @@ from ..utils import logger
from .common import get_op_tensor_name from .common import get_op_tensor_name
__all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars', __all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars',
'load_chkpt_vars',
# 'get_savename_from_varname', 'is_training_name', # 'get_savename_from_varname', 'is_training_name',
'get_checkpoint_path'] 'get_checkpoint_path']
...@@ -112,10 +113,10 @@ class SessionUpdate(object): ...@@ -112,10 +113,10 @@ class SessionUpdate(object):
def dump_session_params(path): def dump_session_params(path):
""" """
Dump value of all TRAINABLE + MODEL variables to a dict, and save as Dump value of all TRAINABLE + MODEL variables to a dict, and save as
npy format (loadable by :class:`DictRestore`). npy/npz format (loadable by :class:`DictRestore`).
Args: Args:
path(str): the path to save the parameters. path(str): the file name to save the parameters. Must ends with npy or npz.
""" """
var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
var.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES)) var.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
...@@ -127,7 +128,12 @@ def dump_session_params(path): ...@@ -127,7 +128,12 @@ def dump_session_params(path):
logger.info("Variables to save to {}:".format(path)) logger.info("Variables to save to {}:".format(path))
keys = sorted(list(result.keys())) keys = sorted(list(result.keys()))
logger.info(pprint.pformat(keys)) logger.info(pprint.pformat(keys))
if path.endswith('.npy'):
np.save(path, result) np.save(path, result)
elif path.endswith('.npz'):
np.savez_compressed(path, **result)
else:
raise ValueError("Don't know which format to use for {}".format(path))
def get_checkpoint_path(model_path): def get_checkpoint_path(model_path):
...@@ -160,7 +166,7 @@ def get_checkpoint_path(model_path): ...@@ -160,7 +166,7 @@ def get_checkpoint_path(model_path):
return model_path return model_path
def dump_chkpt_vars(model_path): def load_chkpt_vars(model_path):
""" Dump all variables from a checkpoint to a dict. """ Dump all variables from a checkpoint to a dict.
Args: Args:
...@@ -177,6 +183,9 @@ def dump_chkpt_vars(model_path): ...@@ -177,6 +183,9 @@ def dump_chkpt_vars(model_path):
result[n] = reader.get_tensor(n) result[n] = reader.get_tensor(n)
return result return result
def dump_chkpt_vars(model_path):
logger.warn("dump_chkpt_vars was renamed to load_chkpt_vars!")
return load_chkpt_vars(model_path)
def is_training_name(name): def is_training_name(name):
""" """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment