Commit 57f542de authored by Yuxin Wu's avatar Yuxin Wu

`load_checkpoint_vars` supports npz

parent 43a44c1d
...@@ -16,7 +16,7 @@ from .common import get_op_tensor_name ...@@ -16,7 +16,7 @@ from .common import get_op_tensor_name
__all__ = ['SessionUpdate', 'dump_session_params', __all__ = ['SessionUpdate', 'dump_session_params',
'load_chkpt_vars', 'save_chkpt_vars', 'load_chkpt_vars', 'save_chkpt_vars',
'load_checkpoint_vars', 'save_checkpoint_vars', 'load_checkpoint_vars', 'save_checkpoint_vars',
'get_checkpoint_path'] 'get_checkpoint_path', 'get_all_checkpoints']
def get_savename_from_varname( def get_savename_from_varname(
...@@ -251,6 +251,10 @@ def load_checkpoint_vars(path): ...@@ -251,6 +251,10 @@ def load_checkpoint_vars(path):
Returns: Returns:
dict: a name:value dict dict: a name:value dict
""" """
if path.endswith(".npz"):
ret = dict(np.load(path))
ret = {get_op_tensor_name(k)[0]: v for k, v in ret.items()}
return ret
path = get_checkpoint_path(path) path = get_checkpoint_path(path)
reader = tfv1.train.NewCheckpointReader(path) reader = tfv1.train.NewCheckpointReader(path)
var_names = reader.get_variable_to_shape_map().keys() var_names = reader.get_variable_to_shape_map().keys()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment