Commit be3409fb authored by Yuxin Wu's avatar Yuxin Wu

execute_only_once and maxsaver

parent a4867550
#!/bin/bash -e
# File: update.sh
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
make clean
sphinx-apidoc -o modules ../tensorpack -f -d 10
make html
...@@ -10,7 +10,8 @@ from collections import deque ...@@ -10,7 +10,8 @@ from collections import deque
import threading import threading
import six import six
from six.moves import range from six.moves import range
from tensorpack.utils import get_rng, logger, memoized, get_dataset_path from tensorpack.utils import (get_rng, logger, memoized,
get_dataset_path, execute_only_once)
from tensorpack.utils.stat import StatCounter from tensorpack.utils.stat import StatCounter
from tensorpack.RL.envbase import RLEnvironment, DiscreteActionSpace from tensorpack.RL.envbase import RLEnvironment, DiscreteActionSpace
...@@ -19,10 +20,6 @@ from ale_python_interface import ALEInterface ...@@ -19,10 +20,6 @@ from ale_python_interface import ALEInterface
__all__ = ['AtariPlayer'] __all__ = ['AtariPlayer']
@memoized
def log_once():
logger.warn("https://github.com/mgbellemare/Arcade-Learning-Environment/pull/171 is not merged!")
ROM_URL = "https://github.com/openai/atari-py/tree/master/atari_py/atari_roms" ROM_URL = "https://github.com/openai/atari-py/tree/master/atari_py/atari_roms"
_ALE_LOCK = threading.Lock() _ALE_LOCK = threading.Lock()
...@@ -56,7 +53,8 @@ class AtariPlayer(RLEnvironment): ...@@ -56,7 +53,8 @@ class AtariPlayer(RLEnvironment):
try: try:
ALEInterface.setLoggerMode(ALEInterface.Logger.Warning) ALEInterface.setLoggerMode(ALEInterface.Logger.Warning)
except AttributeError: except AttributeError:
log_once() if execute_only_once():
logger.warn("https://github.com/mgbellemare/Arcade-Learning-Environment/pull/171 is not merged!")
# avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86 # avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86
with _ALE_LOCK: with _ALE_LOCK:
......
...@@ -88,17 +88,22 @@ class MinSaver(Callback): ...@@ -88,17 +88,22 @@ class MinSaver(Callback):
self.min = None self.min = None
def _get_stat(self): def _get_stat(self):
return self.trainer.stat_holder.get_stat_now(self.monitor_stat) try:
v = self.trainer.stat_holder.get_stat_now(self.monitor_stat)
except KeyError:
v = None
return v
def _need_save(self): def _need_save(self):
if self.reverse: v = self._get_stat()
return self._get_stat() > self.min if not v:
else: return False
return self._get_stat() < self.min return v > self.min if self.reverse else v < self.min
def _trigger_epoch(self): def _trigger_epoch(self):
if self.min is None or self._need_save(): if self.min is None or self._need_save():
self.min = self._get_stat() self.min = self._get_stat()
if self.min:
self._save() self._save()
def _save(self): def _save(self):
...@@ -106,7 +111,7 @@ class MinSaver(Callback): ...@@ -106,7 +111,7 @@ class MinSaver(Callback):
if ckpt is None: if ckpt is None:
raise RuntimeError( raise RuntimeError(
"Cannot find a checkpoint state. Do you forget to use ModelSaver?") "Cannot find a checkpoint state. Do you forget to use ModelSaver?")
path = chpt.model_checkpoint_path path = ckpt.model_checkpoint_path
newname = os.path.join(logger.LOG_DIR, newname = os.path.join(logger.LOG_DIR,
self.filename or self.filename or
('max-' if self.reverse else 'min-' + self.monitor_stat + '.tfmodel')) ('max-' if self.reverse else 'min-' + self.monitor_stat + '.tfmodel'))
......
...@@ -15,23 +15,10 @@ __all__ = ['change_env', ...@@ -15,23 +15,10 @@ __all__ = ['change_env',
'map_arg', 'map_arg',
'get_rng', 'memoized', 'get_rng', 'memoized',
'get_dataset_path', 'get_dataset_path',
'get_tqdm_kwargs' 'get_tqdm_kwargs',
'execute_only_once'
] ]
#def expand_dim_if_necessary(var, dp):
# """
# Args:
# var: a tensor
# dp: a numpy array
# Return a reshaped version of dp, if that makes it match the valid dimension of var
# """
# shape = var.get_shape().as_list()
# valid_shape = [k for k in shape if k]
# if dp.shape == tuple(valid_shape):
# new_shape = [k if k else 1 for k in shape]
# dp = dp.reshape(new_shape)
# return dp
@contextmanager @contextmanager
def change_env(name, val): def change_env(name, val):
oldval = os.environ.get(name, None) oldval = os.environ.get(name, None)
...@@ -104,12 +91,23 @@ def get_rng(obj=None): ...@@ -104,12 +91,23 @@ def get_rng(obj=None):
int(datetime.now().strftime("%Y%m%d%H%M%S%f"))) % 4294967295 int(datetime.now().strftime("%Y%m%d%H%M%S%f"))) % 4294967295
return np.random.RandomState(seed) return np.random.RandomState(seed)
_EXECUTE_HISTORY = set()
def execute_only_once():
    """Return True only the first time it is called from a given call site.

    The call site is identified by the (filename, line number) of the
    caller's frame, so a guarded statement runs its body once per source
    line, while distinct call sites are tracked independently.

    Returns:
        bool: True on the first call from a call site, False afterwards.
    """
    f = inspect.currentframe().f_back
    ident = (f.f_code.co_filename, f.f_lineno)
    # Drop the frame reference promptly: keeping frame objects alive can
    # create reference cycles (see the `inspect` module documentation).
    del f
    if ident in _EXECUTE_HISTORY:
        return False
    _EXECUTE_HISTORY.add(ident)
    return True
def get_dataset_path(*args): def get_dataset_path(*args):
from . import logger
d = os.environ.get('TENSORPACK_DATASET', None) d = os.environ.get('TENSORPACK_DATASET', None)
if d is None: if d is None:
d = os.path.abspath(os.path.join( d = os.path.abspath(os.path.join(
os.path.dirname(__file__), '..', 'dataflow', 'dataset')) os.path.dirname(__file__), '..', 'dataflow', 'dataset'))
if execute_only_once():
from . import logger
logger.info("TENSORPACK_DATASET not set, using {} for dataset.".format(d)) logger.info("TENSORPACK_DATASET not set, using {} for dataset.".format(d))
assert os.path.isdir(d), d assert os.path.isdir(d), d
return os.path.join(d, *args) return os.path.join(d, *args)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment