Commit fded4f40 authored by Yuxin Wu

Check global_step in MinSaver (fix #966)

parent 5772d5fd
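In short: MinSaver/MaxSaver used to copy whatever checkpoint ModelSaver wrote last, even when that checkpoint was not the one written at the iteration that produced the best statistics value. After this commit the callback records the global_step together with the value, and skips the copy when the latest checkpoint's filename does not end with that step. Below is a minimal standalone sketch of the check this commit adds; `checkpoint_matches_step` is a hypothetical helper for illustration, not tensorpack API, and the paths are made up:

    def checkpoint_matches_step(model_checkpoint_path, best_step):
        # TF checkpoints are written as "<prefix>-<global_step>", so a
        # checkpoint whose name does not end with the step of the best
        # value would silently save the wrong model if copied.
        return model_checkpoint_path.endswith(str(best_step))

    # Illustrative paths only:
    assert checkpoint_matches_step("train_log/model-60000", 60000)
    assert not checkpoint_matches_step("train_log/model-59500", 60000)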
@@ -115,10 +115,12 @@ if __name__ == '__main__':
         data=FeedInput(dataset_train),
         callbacks=[
             ModelSaver(),   # save the model after every epoch
-            MaxSaver('validation_accuracy'),  # save the model with highest accuracy (prefix 'validation_')
             InferenceRunner(    # run inference(for validation) after every epoch
                 dataset_test,   # the DataFlow instance used for validation
-                ScalarStats(['cross_entropy_loss', 'accuracy'])),
+                ScalarStats(    # produce `val_accuracy` and `val_cross_entropy_loss`
+                    ['cross_entropy_loss', 'accuracy'], prefix='val')),
+            # MaxSaver has to come after InferenceRunner
+            MaxSaver('val_accuracy'),  # save the model with highest accuracy
         ],
         steps_per_epoch=steps_per_epoch,
         max_epoch=100,
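Why the monitored name changes from 'validation_accuracy' to 'val_accuracy': the diff's inline comments indicate that ScalarStats prepends its prefix (followed by an underscore) to each statistic it produces, so prefix='val' yields `val_accuracy` and `val_cross_entropy_loss`, while the old code's comment shows a 'validation_' prefix. A pure-Python illustration of that naming rule, as inferred from those comments:

    stats = ['cross_entropy_loss', 'accuracy']
    prefix = 'val'   # the old code's comment suggests the previous prefix was 'validation'
    print(['{}_{}'.format(prefix, s) for s in stats])
    # ['val_cross_entropy_loss', 'val_accuracy']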
@@ -102,58 +102,68 @@ class MinSaver(Callback):
                 MinSaver('val-error')

         Note:
-            It assumes that :class:`ModelSaver` is used with the same ``checkpoint_dir``
-            and appears earlier in the callback list.
-            The default for both :class:`ModelSaver` and :class:`MinSaver`
-            is ``checkpoint_dir=logger.get_logger_dir()``
+            1. It assumes that :class:`ModelSaver` is used with the same ``checkpoint_dir``
+               and appears earlier in the callback list.
+               The default for both :class:`ModelSaver` and :class:`MinSaver`
+               is ``checkpoint_dir=logger.get_logger_dir()``
+            2. Callbacks are executed in the order they are defined. Therefore you'd want to
+               use this callback after the callback (e.g. InferenceRunner) that produces the statistics.
         """
         self.monitor_stat = monitor_stat
         self.reverse = reverse
         self.filename = filename
-        self.min = None
+        self.best = None
         self.checkpoint_dir = checkpoint_dir
         if self.checkpoint_dir is None:
             self.checkpoint_dir = logger.get_logger_dir()

     def _get_stat(self):
         try:
-            v = self.trainer.monitors.get_latest(self.monitor_stat)
-        except KeyError:
-            v = None
+            v = self.trainer.monitors.get_history(self.monitor_stat)[-1]
+        except (KeyError, IndexError):
+            v = None, None
         return v

-    def _need_save(self):
-        v = self._get_stat()
-        if not v:
-            return False
-        return v > self.min if self.reverse else v < self.min
-
     def _trigger(self):
-        if self.min is None or self._need_save():
-            self.min = self._get_stat()
-            if self.min:
-                self._save()
+        curr_step, curr_val = self._get_stat()
+        if curr_step is None:
+            return
+
+        if self.best is None or (curr_val > self.best[1] if self.reverse else curr_val < self.best[1]):
+            self.best = (curr_step, curr_val)
+            self._save()

     def _save(self):
         ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir)
         if ckpt is None:
             raise RuntimeError(
-                "Cannot find a checkpoint state. Do you forget to use ModelSaver?")
+                "[MinSaver] Cannot find a checkpoint state. Do you forget to use ModelSaver?")
         path = ckpt.model_checkpoint_path
+
+        extreme_name = 'maximum' if self.reverse else 'minimum'
+        if not path.endswith(str(self.best[0])):
+            logger.warn("[MinSaver] New {} '{}' found at global_step={}, but the latest checkpoint is {}.".format(
+                extreme_name, self.monitor_stat, self.best[0], path
+            ))
+            logger.warn("MinSaver will do nothing this time. "
+                        "The callbacks may have inconsistent frequency or wrong order.")
+            return
+
         newname = os.path.join(self.checkpoint_dir,
                                self.filename or
                                ('max-' + self.monitor_stat if self.reverse else 'min-' + self.monitor_stat))
         files_to_copy = tf.gfile.Glob(path + '*')
         for file_to_copy in files_to_copy:
             tf.gfile.Copy(file_to_copy, file_to_copy.replace(path, newname), overwrite=True)
-        logger.info("Model with {} '{}' saved.".format(
-            'maximum' if self.reverse else 'minimum', self.monitor_stat))
+        logger.info("Model at global_step={} with {} {}={:.5g} saved.".format(
+            self.best[0], extreme_name, self.monitor_stat, self.best[1]))


 class MaxSaver(MinSaver):
     """
     Separately save the model with maximum value of some statistics.
+    See docs of :class:`MinSaver` for details.
     """

     def __init__(self, monitor_stat, filename=None, checkpoint_dir=None):
         """
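The key change in `_trigger` is that it now tracks `(global_step, value)` pairs instead of a bare value, which is what makes the filename check in `_save` possible. A standalone sketch of that best-tracking logic, with made-up history values for illustration:

    # monitors.get_history() yields (global_step, value) pairs; keeping the
    # whole pair of the best value lets _save() verify the checkpoint's step.
    history = [(5000, 0.91), (10000, 0.94), (15000, 0.93)]

    reverse = True   # MaxSaver semantics: larger is better
    best = None
    for curr_step, curr_val in history:
        if best is None or (curr_val > best[1] if reverse else curr_val < best[1]):
            best = (curr_step, curr_val)

    print(best)  # (10000, 0.94) -> look for a checkpoint name ending in "10000"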