Commit 9c5fa4c4 authored by Yuxin Wu's avatar Yuxin Wu

Let ModelSaver follows the official Saver API.

parent 271b2bf1
......@@ -9,6 +9,7 @@ import glob
from .base import Callback
from ..utils import logger
from ..utils.develop import log_deprecated
from ..tfutils.common import get_tf_version_number
__all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
......@@ -19,18 +20,26 @@ class ModelSaver(Callback):
Save the model every epoch.
"""
def __init__(self, keep_recent=10, keep_freq=0.5,
def __init__(self, max_to_keep=10,
keep_checkpoint_every_n_hours=0.5,
checkpoint_dir=None,
var_collections=tf.GraphKeys.GLOBAL_VARIABLES):
var_collections=tf.GraphKeys.GLOBAL_VARIABLES,
keep_recent=None, keep_freq=None):
"""
Args:
keep_recent(int): see ``tf.train.Saver`` documentation.
keep_freq(int): see ``tf.train.Saver`` documentation.
max_to_keep, keep_checkpoint_every_n_hours(int): the same as in ``tf.train.Saver``.
checkpoint_dir (str): Defaults to ``logger.LOG_DIR``.
var_collections (str or list of str): collection of the variables (or list of collections) to save.
"""
self.keep_recent = keep_recent
self.keep_freq = keep_freq
self._max_to_keep = max_to_keep
self._keep_every_n_hours = keep_checkpoint_every_n_hours
if keep_recent is not None or keep_freq is not None:
log_deprecated("ModelSaver(keep_recent=, keep_freq=)", "Use max_to_keep and keep_checkpoint_every_n_hours!")
if keep_recent is not None:
self._max_to_keep = keep_recent
if keep_freq is not None:
self._keep_every_n_hours = keep_freq
if not isinstance(var_collections, list):
var_collections = [var_collections]
self.var_collections = var_collections
......@@ -48,14 +57,14 @@ class ModelSaver(Callback):
if get_tf_version_number() <= 1.1:
self.saver = tf.train.Saver(
var_list=vars,
max_to_keep=self.keep_recent,
keep_checkpoint_every_n_hours=self.keep_freq,
max_to_keep=self._max_to_keep,
keep_checkpoint_every_n_hours=self._keep_every_n_hours,
write_version=tf.train.SaverDef.V2)
else:
self.saver = tf.train.Saver(
var_list=vars,
max_to_keep=self.keep_recent,
keep_checkpoint_every_n_hours=self.keep_freq,
max_to_keep=self._max_to_keep,
keep_checkpoint_every_n_hours=self._keep_every_n_hours,
write_version=tf.train.SaverDef.V2,
save_relative_paths=True)
self.meta_graph_written = False
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment