Commit ce57a145 authored by Yuxin Wu's avatar Yuxin Wu

min/max saver

parent f74ba9a1
...@@ -5,10 +5,9 @@ ...@@ -5,10 +5,9 @@
2. Download models from [model zoo](https://drive.google.com/open?id=0B9IPQTvr2BBkS0VhX0xmS1c5aFk) 2. Download models from [model zoo](https://drive.google.com/open?id=0B9IPQTvr2BBkS0VhX0xmS1c5aFk)
3. `ENV=NAME_OF_ENV ./run-atari.py --load "$ENV".tfmodel --env "$ENV"` 3. `ENV=NAME_OF_ENV ./run-atari.py --load "$ENV".tfmodel --env "$ENV"`
<!-- Models are available for the following gym environments:
-Models are available for the following gym environments:
-
-+ [Breakout-v0](https://gym.openai.com/envs/Breakout-v0)
-->
Note that atari game settings in gym is very different from DeepMind papers, therefore the scores are not comparable. + [Breakout-v0](https://gym.openai.com/envs/Breakout-v0)
+ [AirRaid-v0](https://gym.openai.com/envs/AirRaid-v0)
Note that atari game settings in gym are more difficult than the settings in DeepMind papers, therefore the scores are not comparable.
...@@ -10,7 +10,7 @@ from .base import Callback ...@@ -10,7 +10,7 @@ from .base import Callback
from ..utils import * from ..utils import *
from ..tfutils.varmanip import get_savename_from_varname from ..tfutils.varmanip import get_savename_from_varname
__all__ = ['ModelSaver'] __all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
class ModelSaver(Callback): class ModelSaver(Callback):
""" """
...@@ -81,15 +81,22 @@ because {} will be saved".format(v.name, var_dict[name].name)) ...@@ -81,15 +81,22 @@ because {} will be saved".format(v.name, var_dict[name].name))
logger.exception("Exception in ModelSaver.trigger_epoch!") logger.exception("Exception in ModelSaver.trigger_epoch!")
class MinSaver(Callback): class MinSaver(Callback):
def __init__(self, monitor_stat): def __init__(self, monitor_stat, reverse=True):
self.monitor_stat = monitor_stat self.monitor_stat = monitor_stat
self.reverse = reverse
self.min = None self.min = None
def _get_stat(self): def _get_stat(self):
return self.trainer.stat_holder.get_stat_now(self.monitor_stat) return self.trainer.stat_holder.get_stat_now(self.monitor_stat)
def _need_save(self):
if self.reverse:
return self._get_stat() > self.min
else:
return self._get_stat() < self.min
def _trigger_epoch(self): def _trigger_epoch(self):
if self.min is None or self._get_stat() < self.min: if self.min is None or self._need_save():
self.min = self._get_stat() self.min = self._get_stat()
self._save() self._save()
...@@ -97,10 +104,17 @@ class MinSaver(Callback): ...@@ -97,10 +104,17 @@ class MinSaver(Callback):
ckpt = tf.train.get_checkpoint_state(logger.LOG_DIR) ckpt = tf.train.get_checkpoint_state(logger.LOG_DIR)
if ckpt is None: if ckpt is None:
raise RuntimeError( raise RuntimeError(
"Cannot find a checkpoint state. Do you forget to use ModelSaver before MinSaver?") "Cannot find a checkpoint state. Do you forget to use ModelSaver?")
path = chpt.model_checkpoint_path path = chpt.model_checkpoint_path
newname = os.path.join(logger.LOG_DIR, 'min_' + self.monitor_stat) newname = os.path.join(logger.LOG_DIR,
'max_' if self.reverse else 'min_' + self.monitor_stat)
shutil.copy(path, newname) shutil.copy(path, newname)
logger.info("Model with minimum {} saved.".format(self.monitor_stat)) logger.info("Model with {} {} saved.".format(
'maximum' if self.reverse else 'minimum', self.monitor_stat))
class MaxSaver(MinSaver):
def __init__(self, monitor_stat):
super(MaxSaver, self).__init__(monitor_stat, True)
...@@ -67,9 +67,9 @@ def set_logger_dir(dirname, action=None): ...@@ -67,9 +67,9 @@ def set_logger_dir(dirname, action=None):
if os.path.isdir(dirname): if os.path.isdir(dirname):
if not action: if not action:
_logger.warn("""\ _logger.warn("""\
Directory {} exists! Please either backup/delete it, or use a new directory.""") Directory {} exists! Please either backup/delete it, or use a new directory.""".format(dirname))
_logger.warn("""\ _logger.warn("""\
If you're resuming from a previous run you can choose to keep it.""".format(dirname)) If you're resuming from a previous run you can choose to keep it.""")
_logger.info("Select Action: k (keep) / b (backup) / d (delete) / n (new):") _logger.info("Select Action: k (keep) / b (backup) / d (delete) / n (new):")
while not action: while not action:
action = input().lower().strip() action = input().lower().strip()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment