Commit be3409fb authored by Yuxin Wu's avatar Yuxin Wu

execute_only_once and maxsaver

parent a4867550
#!/bin/bash -e
# File: update.sh
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
# Rebuild the Sphinx documentation from scratch (-e: abort on first error).
make clean
# Regenerate per-module .rst stubs into ./modules from the tensorpack package
# (-f: overwrite existing stubs, -d 10: TOC depth).
sphinx-apidoc -o modules ../tensorpack -f -d 10
make html
......@@ -10,7 +10,8 @@ from collections import deque
import threading
import six
from six.moves import range
from tensorpack.utils import get_rng, logger, memoized, get_dataset_path
from tensorpack.utils import (get_rng, logger, memoized,
get_dataset_path, execute_only_once)
from tensorpack.utils.stat import StatCounter
from tensorpack.RL.envbase import RLEnvironment, DiscreteActionSpace
......@@ -19,10 +20,6 @@ from ale_python_interface import ALEInterface
__all__ = ['AtariPlayer']
@memoized
def log_once():
    """Emit the ALE pull-request warning; @memoized makes repeat calls no-ops."""
    msg = "https://github.com/mgbellemare/Arcade-Learning-Environment/pull/171 is not merged!"
    logger.warn(msg)
ROM_URL = "https://github.com/openai/atari-py/tree/master/atari_py/atari_roms"
_ALE_LOCK = threading.Lock()
......@@ -56,7 +53,8 @@ class AtariPlayer(RLEnvironment):
try:
ALEInterface.setLoggerMode(ALEInterface.Logger.Warning)
except AttributeError:
log_once()
if execute_only_once():
logger.warn("https://github.com/mgbellemare/Arcade-Learning-Environment/pull/171 is not merged!")
# avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86
with _ALE_LOCK:
......
......@@ -88,25 +88,30 @@ class MinSaver(Callback):
self.min = None
def _get_stat(self):
return self.trainer.stat_holder.get_stat_now(self.monitor_stat)
try:
v = self.trainer.stat_holder.get_stat_now(self.monitor_stat)
except KeyError:
v = None
return v
def _need_save(self):
if self.reverse:
return self._get_stat() > self.min
else:
return self._get_stat() < self.min
v = self._get_stat()
if not v:
return False
return v > self.min if self.reverse else v < self.min
def _trigger_epoch(self):
if self.min is None or self._need_save():
self.min = self._get_stat()
self._save()
if self.min:
self._save()
def _save(self):
ckpt = tf.train.get_checkpoint_state(logger.LOG_DIR)
if ckpt is None:
raise RuntimeError(
"Cannot find a checkpoint state. Do you forget to use ModelSaver?")
path = chpt.model_checkpoint_path
path = ckpt.model_checkpoint_path
newname = os.path.join(logger.LOG_DIR,
self.filename or
('max-' if self.reverse else 'min-' + self.monitor_stat + '.tfmodel'))
......
......@@ -15,23 +15,10 @@ __all__ = ['change_env',
'map_arg',
'get_rng', 'memoized',
'get_dataset_path',
'get_tqdm_kwargs'
'get_tqdm_kwargs',
'execute_only_once'
]
#def expand_dim_if_necessary(var, dp):
# """
# Args:
# var: a tensor
# dp: a numpy array
# Return a reshaped version of dp, if that makes it match the valid dimension of var
# """
# shape = var.get_shape().as_list()
# valid_shape = [k for k in shape if k]
# if dp.shape == tuple(valid_shape):
# new_shape = [k if k else 1 for k in shape]
# dp = dp.reshape(new_shape)
# return dp
@contextmanager
def change_env(name, val):
oldval = os.environ.get(name, None)
......@@ -104,13 +91,24 @@ def get_rng(obj=None):
int(datetime.now().strftime("%Y%m%d%H%M%S%f"))) % 4294967295
return np.random.RandomState(seed)
# Call sites (filename, lineno) that have already fired once.
_EXECUTE_HISTORY = set()


def execute_only_once():
    """Return ``True`` the first time a given call site invokes this, ``False`` after.

    The call site is identified by the caller's ``(filename, lineno)``,
    so distinct lines of code each get their own one-shot flag.
    """
    caller = inspect.currentframe().f_back
    call_site = (caller.f_code.co_filename, caller.f_lineno)
    if call_site not in _EXECUTE_HISTORY:
        _EXECUTE_HISTORY.add(call_site)
        return True
    return False
def get_dataset_path(*args):
    """Resolve a path under the dataset directory.

    Uses ``$TENSORPACK_DATASET`` when set; otherwise falls back to
    ``<package>/../dataflow/dataset`` relative to this file and logs
    that fallback once per call site (via ``execute_only_once``).

    Args:
        *args: path components to join under the dataset directory.

    Returns:
        The joined absolute path.  Raises ``AssertionError`` when the
        resolved dataset directory does not exist.
    """
    d = os.environ.get('TENSORPACK_DATASET', None)
    if d is None:
        d = os.path.abspath(os.path.join(
            os.path.dirname(__file__), '..', 'dataflow', 'dataset'))
        if execute_only_once():
            # Import lazily to avoid a circular import at module load time.
            from . import logger
            logger.info("TENSORPACK_DATASET not set, using {} for dataset.".format(d))
    assert os.path.isdir(d), d
    return os.path.join(d, *args)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment