Commit ce57a145 authored by Yuxin Wu's avatar Yuxin Wu

min/max saver

parent f74ba9a1
......@@ -5,10 +5,9 @@
2. Download models from [model zoo](https://drive.google.com/open?id=0B9IPQTvr2BBkS0VhX0xmS1c5aFk)
3. `ENV=NAME_OF_ENV ./run-atari.py --load "$ENV".tfmodel --env "$ENV"`
<!--
-Models are available for the following gym environments:
-
-+ [Breakout-v0](https://gym.openai.com/envs/Breakout-v0)
-->
Models are available for the following gym environments:
Note that atari game settings in gym is very different from DeepMind papers, therefore the scores are not comparable.
+ [Breakout-v0](https://gym.openai.com/envs/Breakout-v0)
+ [AirRaid-v0](https://gym.openai.com/envs/AirRaid-v0)
Note that atari game settings in gym is more difficult than the settings DeepMind papers, therefore the scores are not comparable.
......@@ -10,7 +10,7 @@ from .base import Callback
from ..utils import *
from ..tfutils.varmanip import get_savename_from_varname
__all__ = ['ModelSaver']
__all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
class ModelSaver(Callback):
"""
......@@ -81,15 +81,22 @@ because {} will be saved".format(v.name, var_dict[name].name))
logger.exception("Exception in ModelSaver.trigger_epoch!")
class MinSaver(Callback):
def __init__(self, monitor_stat):
def __init__(self, monitor_stat, reverse=True):
self.monitor_stat = monitor_stat
self.reverse = reverse
self.min = None
def _get_stat(self):
return self.trainer.stat_holder.get_stat_now(self.monitor_stat)
def _need_save(self):
if self.reverse:
return self._get_stat() > self.min
else:
return self._get_stat() < self.min
def _trigger_epoch(self):
if self.min is None or self._get_stat() < self.min:
if self.min is None or self._need_save():
self.min = self._get_stat()
self._save()
......@@ -97,10 +104,17 @@ class MinSaver(Callback):
ckpt = tf.train.get_checkpoint_state(logger.LOG_DIR)
if ckpt is None:
raise RuntimeError(
"Cannot find a checkpoint state. Do you forget to use ModelSaver before MinSaver?")
"Cannot find a checkpoint state. Do you forget to use ModelSaver?")
path = chpt.model_checkpoint_path
newname = os.path.join(logger.LOG_DIR, 'min_' + self.monitor_stat)
newname = os.path.join(logger.LOG_DIR,
'max_' if self.reverse else 'min_' + self.monitor_stat)
shutil.copy(path, newname)
logger.info("Model with minimum {} saved.".format(self.monitor_stat))
logger.info("Model with {} {} saved.".format(
'maximum' if self.reverse else 'minimum', self.monitor_stat))
class MaxSaver(MinSaver):
def __init__(self, monitor_stat):
super(MaxSaver, self).__init__(monitor_stat, True)
......@@ -67,9 +67,9 @@ def set_logger_dir(dirname, action=None):
if os.path.isdir(dirname):
if not action:
_logger.warn("""\
Directory {} exists! Please either backup/delete it, or use a new directory.""")
Directory {} exists! Please either backup/delete it, or use a new directory.""".format(dirname))
_logger.warn("""\
If you're resuming from a previous run you can choose to keep it.""".format(dirname))
If you're resuming from a previous run you can choose to keep it.""")
_logger.info("Select Action: k (keep) / b (backup) / d (delete) / n (new):")
while not action:
action = input().lower().strip()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment