Commit 73baa9ae authored by Yuxin Wu

get_all_checkpoints

parent 425f9d27
@@ -407,6 +407,8 @@ _DEPRECATED_NAMES = set([
     "get_model_loader",
     # renamed items that should not appear in docs
+    'load_chkpt_vars',
+    'save_chkpt_vars',
     'DumpTensor',
     'DumpParamAsImage',
     'get_nr_gpu',

tensorpack/tfutils/varmanip.py
 # -*- coding: utf-8 -*-
 # File: varmanip.py
+import glob
+import operator
 import numpy as np
 import os
 import pprint
@@ -12,7 +14,9 @@ from ..utils import logger
 from .common import get_op_tensor_name

 __all__ = ['SessionUpdate', 'dump_session_params',
-           'load_chkpt_vars', 'save_chkpt_vars', 'get_checkpoint_path']
+           'load_chkpt_vars', 'save_chkpt_vars',
+           'load_checkpoint_vars', 'save_checkpoint_vars',
+           'get_checkpoint_path']


 def get_savename_from_varname(
@@ -146,19 +150,19 @@ def dump_session_params(path):
         path(str): the file name to save the parameters. Must end with npz.
     """
     # save variables that are GLOBAL, and either TRAINABLE or MODEL
-    var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
-    var.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
+    var = tfv1.get_collection(tfv1.GraphKeys.TRAINABLE_VARIABLES)
+    var.extend(tfv1.get_collection(tfv1.GraphKeys.MODEL_VARIABLES))
     # TODO dedup
     assert len(set(var)) == len(var), "TRAINABLE and MODEL variables have duplication!"
-    gvars = {k.name for k in tf.global_variables()}
+    gvars = {k.name for k in tfv1.global_variables()}
     var = [v for v in var if v.name in gvars]
     result = {}
     for v in var:
         result[v.name] = v.eval()
-    save_chkpt_vars(result, path)
+    save_checkpoint_vars(result, path)


-def save_chkpt_vars(dic, path):
+def save_checkpoint_vars(dic, path):
     """
     Save variables in dic to path.
@@ -174,13 +178,13 @@ def save_chkpt_vars(dic, path):
     if path.endswith('.npz'):
         np.savez_compressed(path, **dic)
     else:
-        with tf.Graph().as_default(), \
-                tf.Session() as sess:
+        with tfv1.Graph().as_default(), \
+                tfv1.Session() as sess:
             for k, v in six.iteritems(dic):
                 k = get_op_tensor_name(k)[0]
-                _ = tf.Variable(name=k, initial_value=v)  # noqa
-            sess.run(tf.global_variables_initializer())
-            saver = tf.train.Saver()
+                _ = tfv1.Variable(name=k, initial_value=v)  # noqa
+            sess.run(tfv1.global_variables_initializer())
+            saver = tfv1.train.Saver()
             saver.save(sess, path, write_meta_graph=False)
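
A minimal usage sketch of the renamed saver (not part of the commit; the module path and the variable dict are assumptions). save_checkpoint_vars picks the format from the path: an .npz suffix goes through np.savez_compressed, anything else is written as a TF checkpoint by tfv1.train.Saver:

    import numpy as np
    from tensorpack.tfutils.varmanip import save_checkpoint_vars

    # hypothetical variable dict: TF-style tensor names -> numpy values
    params = {"conv0/W:0": np.zeros((3, 3, 3, 64), dtype="float32")}
    save_checkpoint_vars(params, "params.npz")  # npz branch
    save_checkpoint_vars(params, "./model")     # checkpoint branch: writes model.index, model.data-*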
@@ -197,7 +201,7 @@ def get_checkpoint_path(path):
     path = os.path.join('.', path)  # avoid #4921 and #6142
     if os.path.basename(path) == 'checkpoint':
         assert tfv1.gfile.Exists(path), path
-        path = tf.train.latest_checkpoint(os.path.dirname(path))
+        path = tfv1.train.latest_checkpoint(os.path.dirname(path))
         # to be consistent with either v1 or v2

     # fix paths if provided a wrong one
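
For context, a sketch of how the normalization above behaves (the log directory and step number are hypothetical): handing get_checkpoint_path the bookkeeping file literally named "checkpoint" resolves it to the newest checkpoint prefix in that directory via tfv1.train.latest_checkpoint:

    from tensorpack.tfutils.varmanip import get_checkpoint_path

    # train_log/run1 is assumed to contain "checkpoint", "model-100.index",
    # and "model-100.data-00000-of-00001"
    path = get_checkpoint_path("train_log/run1/checkpoint")
    print(path)  # e.g. "train_log/run1/model-100" -- a prefix, not a file on disk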
@@ -214,7 +218,31 @@ def get_checkpoint_path(path):
     return path


-def load_chkpt_vars(path):
+def get_all_checkpoints(dir: str, prefix: str = "model"):
+    """
+    Get a sorted list of all checkpoints found in the directory.
+
+    Args:
+        dir (str): checkpoint directory
+        prefix (str): common prefix among all checkpoints (without the final "-")
+
+    Returns:
+        list[(str, int)]: list of (name, step) sorted by step.
+            Name is a checkpoint handle that can be passed to
+            `tf.train.NewCheckpointReader` or :func:`load_checkpoint_vars`.
+    """
+    def step_from_filename(name):
+        # e.g. "model-100.index" -> 100
+        name = os.path.basename(name)
+        name = name[len(f"{prefix}-"):-len(".index")]
+        return int(name)
+
+    checkpoints = glob.glob(os.path.join(dir, f"{prefix}-*.index"))
+    # strip ".index" so each name is a checkpoint handle, as promised above
+    checkpoints = [(f[:-len(".index")], step_from_filename(f)) for f in checkpoints]
+    checkpoints = sorted(checkpoints, key=operator.itemgetter(1))
+    return checkpoints
+
+
+def load_checkpoint_vars(path):
     """ Load all variables from a checkpoint to a dict.

     Args:
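
A sketch of the new function together with the renamed loader (the log directory and step numbers are hypothetical). Since each returned name is a checkpoint handle, it can be fed straight to load_checkpoint_vars:

    from tensorpack.tfutils.varmanip import get_all_checkpoints, load_checkpoint_vars

    ckpts = get_all_checkpoints("train_log/run1")  # e.g. [("train_log/run1/model-100", 100), ...]
    latest_name, latest_step = ckpts[-1]           # list is sorted ascending by step
    params = load_checkpoint_vars(latest_name)     # dict: variable name -> numpy array
    print(latest_step, len(params))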
@@ -257,3 +285,7 @@ def is_training_name(name):
     if name.startswith('apply_gradients'):
         return True
     return False
+
+
+load_chkpt_vars = load_checkpoint_vars
+save_chkpt_vars = save_checkpoint_vars
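
The old names survive as plain aliases of the new functions, so user code importing load_chkpt_vars or save_chkpt_vars keeps working unchanged; a quick check (module path as above):

    from tensorpack.tfutils.varmanip import load_chkpt_vars, load_checkpoint_vars

    assert load_chkpt_vars is load_checkpoint_vars  # same function object, not a wrapper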