Commit 38818ff2 authored by Maciej Jaśkowski's avatar Maciej Jaśkowski Committed by Yuxin Wu

MinSaver takes into account checkpoint_dir. (#520)

parent f7729086
......@@ -89,31 +89,33 @@ class MinSaver(Callback):
"""
Separately save the model with minimum value of some statistics.
"""
def __init__(self, monitor_stat, reverse=False, filename=None):
def __init__(self, monitor_stat, reverse=False, filename=None, checkpoint_dir=None):
"""
Args:
monitor_stat(str): the name of the statistics.
reverse (bool): if True, will save the maximum.
filename (str): the name for the saved model.
Defaults to ``min-{monitor_stat}.tfmodel``.
Example:
Save the model with minimum validation error to
"min-val-error.tfmodel":
.. code-block:: python
MinSaver('val-error')
Note:
It assumes that :class:`ModelSaver` is used with
``checkpoint_dir=logger.get_logger_dir()`` (the default). And it will save
the same ``checkpoint_dir``. And it will save
the model to that directory as well.
The default for both :class:`ModelSaver` and :class:`MinSaver`
is ``checkpoint_dir=logger.get_logger_dir()``
"""
self.monitor_stat = monitor_stat
self.reverse = reverse
self.filename = filename
self.min = None
self.checkpoint_dir = checkpoint_dir
if self.checkpoint_dir is None:
self.checkpoint_dir = logger.get_logger_dir()
def _get_stat(self):
try:
......@@ -135,13 +137,13 @@ class MinSaver(Callback):
self._save()
def _save(self):
ckpt = tf.train.get_checkpoint_state(logger.get_logger_dir())
ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir)
if ckpt is None:
raise RuntimeError(
"Cannot find a checkpoint state. Do you forget to use ModelSaver?")
path = ckpt.model_checkpoint_path
newname = os.path.join(logger.get_logger_dir(),
newname = os.path.join(self.checkpoint_dir,
self.filename or
('max-' + self.monitor_stat if self.reverse else 'min-' + self.monitor_stat))
files_to_copy = tf.gfile.Glob(path + '*')
......@@ -155,11 +157,11 @@ class MaxSaver(MinSaver):
"""
Separately save the model with maximum value of some statistics.
"""
def __init__(self, monitor_stat, filename=None):
def __init__(self, monitor_stat, filename=None, checkpoint_dir=None):
"""
Args:
monitor_stat(str): the name of the statistics.
filename (str): the name for the saved model.
Defaults to ``max-{monitor_stat}.tfmodel``.
"""
super(MaxSaver, self).__init__(monitor_stat, True, filename=filename)
super(MaxSaver, self).__init__(monitor_stat, True, filename=filename, checkpoint_dir=checkpoint_dir)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment