Commit fded4f40 authored by Yuxin Wu

Check global_step in MinSaver (fix #966)

parent 5772d5fd
@@ -115,10 +115,12 @@ if __name__ == '__main__':
         data=FeedInput(dataset_train),
         callbacks=[
             ModelSaver(),   # save the model after every epoch
-            MaxSaver('validation_accuracy'),  # save the model with highest accuracy (prefix 'validation_')
             InferenceRunner(    # run inference (for validation) after every epoch
                 dataset_test,   # the DataFlow instance used for validation
-                ScalarStats(['cross_entropy_loss', 'accuracy'])),
+                ScalarStats(  # produce `val_accuracy` and `val_cross_entropy_loss`
+                    ['cross_entropy_loss', 'accuracy'], prefix='val')),
+            # MaxSaver has to come after InferenceRunner
+            MaxSaver('val_accuracy'),  # save the model with highest accuracy
         ],
         steps_per_epoch=steps_per_epoch,
         max_epoch=100,
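
Why the example had to change: tensorpack runs callbacks in the order they are registered, so a MaxSaver placed before InferenceRunner would compare a statistic from the *previous* epoch against the checkpoint ModelSaver just wrote for the current one. Assembled, the corrected configuration looks roughly like the sketch below; `Model`, `dataset_train`, `dataset_test` and `steps_per_epoch` are assumed to come from the surrounding example and are not defined here.

    # A minimal sketch of the fixed callback ordering (assumes tensorpack's
    # top-level exports of this era; not a verbatim copy of the example).
    from tensorpack import (TrainConfig, FeedInput, ModelSaver, MaxSaver,
                            InferenceRunner, ScalarStats)

    config = TrainConfig(
        model=Model(),                 # `Model` comes from the example, assumed here
        data=FeedInput(dataset_train),
        callbacks=[
            ModelSaver(),              # writes checkpoint `model-<global_step>` each epoch
            InferenceRunner(           # produces `val_accuracy` for the current epoch
                dataset_test,
                ScalarStats(['cross_entropy_loss', 'accuracy'], prefix='val')),
            MaxSaver('val_accuracy'),  # reads `val_accuracy`; must run after InferenceRunner
        ],
        steps_per_epoch=steps_per_epoch,
        max_epoch=100,
    )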
@@ -102,58 +102,68 @@ class MinSaver(Callback):
                 MinSaver('val-error')
 
         Note:
-            It assumes that :class:`ModelSaver` is used with the same ``checkpoint_dir``
-            and appears earlier in the callback list.
-            The default for both :class:`ModelSaver` and :class:`MinSaver`
-            is ``checkpoint_dir=logger.get_logger_dir()``
+            1. It assumes that :class:`ModelSaver` is used with the same ``checkpoint_dir``
+               and appears earlier in the callback list.
+               The default for both :class:`ModelSaver` and :class:`MinSaver`
+               is ``checkpoint_dir=logger.get_logger_dir()``
+
+            2. Callbacks are executed in the order they are defined. Therefore you'd want to
+               use this callback after the callback (e.g. InferenceRunner) that produces the statistics.
         """
         self.monitor_stat = monitor_stat
         self.reverse = reverse
         self.filename = filename
-        self.min = None
+        self.best = None
 
         self.checkpoint_dir = checkpoint_dir
         if self.checkpoint_dir is None:
             self.checkpoint_dir = logger.get_logger_dir()
 
     def _get_stat(self):
         try:
-            v = self.trainer.monitors.get_latest(self.monitor_stat)
-        except KeyError:
-            v = None
+            v = self.trainer.monitors.get_history(self.monitor_stat)[-1]
+        except (KeyError, IndexError):
+            v = None, None
         return v
 
-    def _need_save(self):
-        v = self._get_stat()
-        if not v:
-            return False
-        return v > self.min if self.reverse else v < self.min
-
     def _trigger(self):
-        if self.min is None or self._need_save():
-            self.min = self._get_stat()
-            if self.min:
-                self._save()
+        curr_step, curr_val = self._get_stat()
+        if curr_step is None:
+            return
+        if self.best is None or (curr_val > self.best[1] if self.reverse else curr_val < self.best[1]):
+            self.best = (curr_step, curr_val)
+            self._save()
 
     def _save(self):
         ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir)
         if ckpt is None:
             raise RuntimeError(
-                "Cannot find a checkpoint state. Do you forget to use ModelSaver?")
+                "[MinSaver] Cannot find a checkpoint state. Did you forget to use ModelSaver?")
         path = ckpt.model_checkpoint_path
+        extreme_name = 'maximum' if self.reverse else 'minimum'
+        if not path.endswith(str(self.best[0])):
+            logger.warn("[MinSaver] New {} '{}' found at global_step={}, but the latest checkpoint is {}.".format(
+                extreme_name, self.monitor_stat, self.best[0], path
+            ))
+            logger.warn("MinSaver will do nothing this time. "
+                        "The callbacks may have inconsistent frequency or wrong order.")
+            return
 
         newname = os.path.join(self.checkpoint_dir,
                                self.filename or
                                ('max-' + self.monitor_stat if self.reverse else 'min-' + self.monitor_stat))
         files_to_copy = tf.gfile.Glob(path + '*')
         for file_to_copy in files_to_copy:
             tf.gfile.Copy(file_to_copy, file_to_copy.replace(path, newname), overwrite=True)
-        logger.info("Model with {} '{}' saved.".format(
-            'maximum' if self.reverse else 'minimum', self.monitor_stat))
+        logger.info("Model at global_step={} with {} {}={:.5g} saved.".format(
+            self.best[0], extreme_name, self.monitor_stat, self.best[1]))
 
 
 class MaxSaver(MinSaver):
     """
     Separately save the model with maximum value of some statistics.
     See docs of :class:`MinSaver` for details.
     """
     def __init__(self, monitor_stat, filename=None, checkpoint_dir=None):
         """
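
The core of the fix for #966: `_get_stat` now returns the `(global_step, value)` pair recorded by the monitors (`get_history(...)[-1]`) instead of the bare value, and `_save` only copies the checkpoint when its filename ends with that step, since ModelSaver names checkpoints `model-<global_step>`. Otherwise the checkpoint on disk was written at a different step than the best statistic, and copying it would save the wrong model. A hypothetical standalone illustration of that guard (the helper name is invented for this sketch, not part of tensorpack):

    # Hypothetical helper mirroring the `path.endswith(str(self.best[0]))`
    # check added above. tensorpack's ModelSaver writes checkpoints named
    # like ".../model-12345", so a suffix test compares the steps.
    def checkpoint_matches_step(model_checkpoint_path, best_step):
        return model_checkpoint_path.endswith(str(best_step))

    # The callback saves only when the latest checkpoint was written at the
    # same global_step as the best statistic; otherwise it warns and skips.
    assert checkpoint_matches_step("/train_log/mnist/model-5000", 5000)
    assert not checkpoint_matches_step("/train_log/mnist/model-5000", 4500)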