Commit 38818ff2 authored by Maciej Jaśkowski's avatar Maciej Jaśkowski Committed by Yuxin Wu

MinSaver takes into account checkpoint_dir. (#520)

parent f7729086
...@@ -89,31 +89,33 @@ class MinSaver(Callback): ...@@ -89,31 +89,33 @@ class MinSaver(Callback):
""" """
Separately save the model with minimum value of some statistics. Separately save the model with minimum value of some statistics.
""" """
def __init__(self, monitor_stat, reverse=False, filename=None): def __init__(self, monitor_stat, reverse=False, filename=None, checkpoint_dir=None):
""" """
Args: Args:
monitor_stat(str): the name of the statistics. monitor_stat(str): the name of the statistics.
reverse (bool): if True, will save the maximum. reverse (bool): if True, will save the maximum.
filename (str): the name for the saved model. filename (str): the name for the saved model.
Defaults to ``min-{monitor_stat}.tfmodel``. Defaults to ``min-{monitor_stat}.tfmodel``.
Example: Example:
Save the model with minimum validation error to Save the model with minimum validation error to
"min-val-error.tfmodel": "min-val-error.tfmodel":
.. code-block:: python .. code-block:: python
MinSaver('val-error') MinSaver('val-error')
Note: Note:
It assumes that :class:`ModelSaver` is used with It assumes that :class:`ModelSaver` is used with
``checkpoint_dir=logger.get_logger_dir()`` (the default). And it will save the same ``checkpoint_dir``. And it will save
the model to that directory as well. the model to that directory as well.
The default for both :class:`ModelSaver` and :class:`MinSaver`
is ``checkpoint_dir=logger.get_logger_dir()``
""" """
self.monitor_stat = monitor_stat self.monitor_stat = monitor_stat
self.reverse = reverse self.reverse = reverse
self.filename = filename self.filename = filename
self.min = None self.min = None
self.checkpoint_dir = checkpoint_dir
if self.checkpoint_dir is None:
self.checkpoint_dir = logger.get_logger_dir()
def _get_stat(self): def _get_stat(self):
try: try:
...@@ -135,13 +137,13 @@ class MinSaver(Callback): ...@@ -135,13 +137,13 @@ class MinSaver(Callback):
self._save() self._save()
def _save(self): def _save(self):
ckpt = tf.train.get_checkpoint_state(logger.get_logger_dir()) ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir)
if ckpt is None: if ckpt is None:
raise RuntimeError( raise RuntimeError(
"Cannot find a checkpoint state. Do you forget to use ModelSaver?") "Cannot find a checkpoint state. Do you forget to use ModelSaver?")
path = ckpt.model_checkpoint_path path = ckpt.model_checkpoint_path
newname = os.path.join(logger.get_logger_dir(), newname = os.path.join(self.checkpoint_dir,
self.filename or self.filename or
('max-' + self.monitor_stat if self.reverse else 'min-' + self.monitor_stat)) ('max-' + self.monitor_stat if self.reverse else 'min-' + self.monitor_stat))
files_to_copy = tf.gfile.Glob(path + '*') files_to_copy = tf.gfile.Glob(path + '*')
...@@ -155,11 +157,11 @@ class MaxSaver(MinSaver): ...@@ -155,11 +157,11 @@ class MaxSaver(MinSaver):
""" """
Separately save the model with maximum value of some statistics. Separately save the model with maximum value of some statistics.
""" """
def __init__(self, monitor_stat, filename=None): def __init__(self, monitor_stat, filename=None, checkpoint_dir=None):
""" """
Args: Args:
monitor_stat(str): the name of the statistics. monitor_stat(str): the name of the statistics.
filename (str): the name for the saved model. filename (str): the name for the saved model.
Defaults to ``max-{monitor_stat}.tfmodel``. Defaults to ``max-{monitor_stat}.tfmodel``.
""" """
super(MaxSaver, self).__init__(monitor_stat, True, filename=filename) super(MaxSaver, self).__init__(monitor_stat, True, filename=filename, checkpoint_dir=checkpoint_dir)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment