Commit 73baa9ae authored by Yuxin Wu

get_all_checkpoints

parent 425f9d27
@@ -407,6 +407,8 @@ _DEPRECATED_NAMES = set([
     "get_model_loader",

     # renamed items that should not appear in docs
+    'load_chkpt_vars',
+    'save_chkpt_vars',
     'DumpTensor',
     'DumpParamAsImage',
     'get_nr_gpu',
...
 # -*- coding: utf-8 -*-
 # File: varmanip.py
+import glob
+import operator
 import numpy as np
 import os
 import pprint
@@ -12,7 +14,9 @@ from ..utils import logger
 from .common import get_op_tensor_name

 __all__ = ['SessionUpdate', 'dump_session_params',
-           'load_chkpt_vars', 'save_chkpt_vars', 'get_checkpoint_path']
+           'load_chkpt_vars', 'save_chkpt_vars',
+           'load_checkpoint_vars', 'save_checkpoint_vars',
+           'get_checkpoint_path']


 def get_savename_from_varname(
@@ -146,19 +150,19 @@ def dump_session_params(path):
         path(str): the file name to save the parameters. Must end with .npz.
     """
     # save variables that are GLOBAL, and either TRAINABLE or MODEL
-    var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
-    var.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
+    var = tfv1.get_collection(tfv1.GraphKeys.TRAINABLE_VARIABLES)
+    var.extend(tfv1.get_collection(tfv1.GraphKeys.MODEL_VARIABLES))
     # TODO dedup
     assert len(set(var)) == len(var), "TRAINABLE and MODEL variables have duplication!"
-    gvars = {k.name for k in tf.global_variables()}
+    gvars = {k.name for k in tfv1.global_variables()}
     var = [v for v in var if v.name in gvars]
     result = {}
     for v in var:
         result[v.name] = v.eval()
-    save_chkpt_vars(result, path)
+    save_checkpoint_vars(result, path)


-def save_chkpt_vars(dic, path):
+def save_checkpoint_vars(dic, path):
     """
     Save variables in dic to path.
@@ -174,13 +178,13 @@ def save_chkpt_vars(dic, path):
     if path.endswith('.npz'):
         np.savez_compressed(path, **dic)
     else:
-        with tf.Graph().as_default(), \
-                tf.Session() as sess:
+        with tfv1.Graph().as_default(), \
+                tfv1.Session() as sess:
             for k, v in six.iteritems(dic):
                 k = get_op_tensor_name(k)[0]
-                _ = tf.Variable(name=k, initial_value=v)  # noqa
-            sess.run(tf.global_variables_initializer())
-            saver = tf.train.Saver()
+                _ = tfv1.Variable(name=k, initial_value=v)  # noqa
+            sess.run(tfv1.global_variables_initializer())
+            saver = tfv1.train.Saver()
             saver.save(sess, path, write_meta_graph=False)
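
For context on the renamed saver: save_checkpoint_vars writes a dict of numpy values either to an .npz archive or to a TF checkpoint, chosen by the file extension. A minimal sketch of the npz branch, assuming the module is importable as tensorpack.tfutils.varmanip (the variable names and path below are hypothetical):

import numpy as np
from tensorpack.tfutils.varmanip import save_checkpoint_vars

# Hypothetical dict mapping variable names to numpy values.
params = {
    "fc/W:0": np.random.rand(10, 4).astype("float32"),
    "fc/b:0": np.zeros(4, dtype="float32"),
}

# The '.npz' suffix takes the np.savez_compressed branch above; any other
# path would be written as a TF checkpoint through tfv1.train.Saver.
save_checkpoint_vars(params, "/tmp/params.npz")

# Verify the round trip with plain numpy.
restored = {k: v for k, v in np.load("/tmp/params.npz").items()}
assert np.allclose(restored["fc/W:0"], params["fc/W:0"])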
@@ -197,7 +201,7 @@ def get_checkpoint_path(path):
         path = os.path.join('.', path)  # avoid #4921 and #6142
     if os.path.basename(path) == 'checkpoint':
         assert tfv1.gfile.Exists(path), path
-        path = tf.train.latest_checkpoint(os.path.dirname(path))
+        path = tfv1.train.latest_checkpoint(os.path.dirname(path))
         # to be consistent with either v1 or v2

     # fix paths if provided a wrong one
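
One behavior worth noting in get_checkpoint_path: when handed the 'checkpoint' bookkeeping file itself, it resolves to the latest checkpoint prefix via tfv1.train.latest_checkpoint. A sketch, with a hypothetical log directory:

from tensorpack.tfutils.varmanip import get_checkpoint_path

# If train_log/run1/checkpoint names model-2000 as the most recent save,
# this resolves to the usable prefix "train_log/run1/model-2000".
path = get_checkpoint_path("train_log/run1/checkpoint")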
@@ -214,7 +218,31 @@ def get_checkpoint_path(path):
     return path


-def load_chkpt_vars(path):
+def get_all_checkpoints(dir: str, prefix: str = "model"):
+    """
+    Get a sorted list of all checkpoints found in the directory.
+
+    Args:
+        dir (str): checkpoint directory
+        prefix (str): common prefix among all checkpoints (without the final "-")
+
+    Returns:
+        list[(str, int)]: list of (name, step) sorted by step.
+            Name is a checkpoint handle that can be passed to
+            `tf.train.NewCheckpointReader` or :func:`load_checkpoint_vars`.
+    """
+    def step_from_filename(name):
+        name = os.path.basename(name)
+        name = name[len(f"{prefix}-"):-len(".index")]
+        return int(name)
+
+    checkpoints = glob.glob(os.path.join(dir, f"{prefix}-*.index"))
+    checkpoints = [(f, step_from_filename(f)) for f in checkpoints]
+    checkpoints = sorted(checkpoints, key=operator.itemgetter(1))
+    return checkpoints
+
+
+def load_checkpoint_vars(path):
     """ Load all variables from a checkpoint to a dict.

     Args:
@@ -257,3 +285,7 @@ def is_training_name(name):
     if name.startswith('apply_gradients'):
         return True
     return False
+
+
+load_chkpt_vars = load_checkpoint_vars
+save_chkpt_vars = save_checkpoint_vars
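
Taken together, the new get_all_checkpoints pairs naturally with the renamed load_checkpoint_vars: the former enumerates (path, step) handles by globbing "<prefix>-<step>.index" files, the latter reads one checkpoint into a dict. A usage sketch with a hypothetical training directory:

from tensorpack.tfutils.varmanip import (
    get_all_checkpoints, load_checkpoint_vars, load_chkpt_vars)

# Hypothetical directory holding model-1000.index, model-2000.index, ...
for ckpt, step in get_all_checkpoints("train_log/run1", prefix="model"):
    variables = load_checkpoint_vars(ckpt)  # {variable name: numpy array}
    print("step", step, "->", len(variables), "variables")

# The old names stay importable as aliases, so existing call sites keep working.
assert load_chkpt_vars is load_checkpoint_vars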