Commit 57f542de authored by Yuxin Wu's avatar Yuxin Wu

`load_checkpoint_vars` supports npz

parent 43a44c1d
......@@ -16,7 +16,7 @@ from .common import get_op_tensor_name
__all__ = ['SessionUpdate', 'dump_session_params',
'load_chkpt_vars', 'save_chkpt_vars',
'load_checkpoint_vars', 'save_checkpoint_vars',
'get_checkpoint_path']
'get_checkpoint_path', 'get_all_checkpoints']
def get_savename_from_varname(
......@@ -251,6 +251,10 @@ def load_checkpoint_vars(path):
Returns:
dict: a name:value dict
"""
if path.endswith(".npz"):
ret = dict(np.load(path))
ret = {get_op_tensor_name(k)[0]: v for k, v in ret.items()}
return ret
path = get_checkpoint_path(path)
reader = tfv1.train.NewCheckpointReader(path)
var_names = reader.get_variable_to_shape_map().keys()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment