Commit 0a13579c authored by Yuxin Wu

Delete tensorpack.RL

parent b2ff230d
## DEPRECATED
Please use gym or other APIs.
# -*- coding: UTF-8 -*-
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

from pkgutil import iter_modules
from ..utils.develop import log_deprecated
import os
import os.path

__all__ = []

"""
This module should be removed in the future.
"""

log_deprecated("tensorpack.RL", "Please use gym or other APIs instead!", "2017-12-31")
def _global_import(name):
    # Import the submodule and hoist its public names into this package's
    # namespace, so they can be imported directly from the package.
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
    del globals()[name]
    for k in lst:
        globals()[k] = p.__dict__[k]
        __all__.append(k)


for _, module_name, _ in iter_modules(
        [os.path.dirname(__file__)]):
    if not module_name.startswith('_'):
        _global_import(module_name)
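The loop above re-exports every public name of each submodule at the package level. A minimal usage sketch (an editor's illustration, assuming this deprecated package is still installed):

from tensorpack.RL import GymEnv            # re-exported by the loop above
from tensorpack.RL.gymenv import GymEnv     # equivalent direct import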
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: common.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

from collections import deque

from .envbase import ProxyPlayer

__all__ = ['PreventStuckPlayer', 'LimitLengthPlayer', 'AutoRestartPlayer',
           'MapPlayerState']
class PreventStuckPlayer(ProxyPlayer):
    """ Prevent the player from getting stuck (repeating a no-op)
    by inserting a different action. Useful in games such as Atari Breakout
    where the agent needs to press the 'start' button to start playing.

    It does auto-reset, but doesn't auto-restart the underlying player.
    """
    # TODO hash the state as well?

    def __init__(self, player, nr_repeat, action):
        """
        Args:
            nr_repeat (int): trigger the 'action' after this many repeated actions.
            action: the action to trigger in order to get out of being stuck.
        """
        super(PreventStuckPlayer, self).__init__(player)
        self.act_que = deque(maxlen=nr_repeat)
        self.trigger_action = action

    def action(self, act):
        self.act_que.append(act)
        # If the last `nr_repeat` actions were all identical, substitute
        # the trigger action for the requested one.
        if self.act_que.count(self.act_que[0]) == self.act_que.maxlen:
            act = self.trigger_action
        r, isOver = self.player.action(act)
        if isOver:
            self.act_que.clear()
        return (r, isOver)

    def restart_episode(self):
        super(PreventStuckPlayer, self).restart_episode()
        self.act_que.clear()
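# Editor's usage sketch (hypothetical `base_player`; the trigger action
# index 1 is an assumption, e.g. 'FIRE' in Atari Breakout):
#
#   player = PreventStuckPlayer(base_player, nr_repeat=30, action=1)
#   r, isOver = player.action(0)   # after 30 identical actions, 1 is issued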
class LimitLengthPlayer(ProxyPlayer):
    """ Limit the total number of actions in an episode.
    Will restart the underlying player on timeout.
    """

    def __init__(self, player, limit):
        """
        Args:
            limit (int): the maximum number of actions allowed in one episode.
        """
        super(LimitLengthPlayer, self).__init__(player)
        self.limit = limit
        self.cnt = 0

    def action(self, act):
        r, isOver = self.player.action(act)
        self.cnt += 1
        # Force the episode to end once the step limit is reached.
        if self.cnt >= self.limit:
            isOver = True
            self.finish_episode()
            self.restart_episode()
        if isOver:
            self.cnt = 0
        return (r, isOver)

    def restart_episode(self):
        self.player.restart_episode()
        self.cnt = 0
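# Editor's usage sketch (hypothetical `base_player`): cap every episode at
# 10000 actions; on timeout the wrapper itself finishes and restarts the
# underlying episode.
#
#   player = LimitLengthPlayer(base_player, limit=10000)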
class AutoRestartPlayer(ProxyPlayer):
    """ Auto-restart the player when an episode ends,
    in case the underlying player wasn't designed to do so.
    """

    def action(self, act):
        r, isOver = self.player.action(act)
        if isOver:
            self.player.finish_episode()
            self.player.restart_episode()
        return r, isOver
class MapPlayerState(ProxyPlayer):
    """ Map the state of the underlying player by a function. """

    def __init__(self, player, func):
        """
        Args:
            func: takes the old state and returns a new state.
        """
        super(MapPlayerState, self).__init__(player)
        self.func = func

    def current_state(self):
        return self.func(self.player.current_state())
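For illustration, a hedged sketch of wrapping a player with `MapPlayerState` to preprocess frames (`base_player` and the cv2 grayscale conversion are assumptions, not part of this module):

import cv2
# Convert each RGB frame to grayscale before the agent sees it.
player = MapPlayerState(base_player,
                        lambda s: cv2.cvtColor(s, cv2.COLOR_RGB2GRAY))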
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: envbase.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

from abc import abstractmethod, ABCMeta
from collections import defaultdict
import six

from ..utils.utils import get_rng

__all__ = ['RLEnvironment', 'ProxyPlayer',
           'DiscreteActionSpace']
@six.add_metaclass(ABCMeta)
class RLEnvironment(object):
    """ Base class of RL environment. """

    def __init__(self):
        self.reset_stat()

    @abstractmethod
    def current_state(self):
        """
        Observe, return a state representation.
        """

    @abstractmethod
    def action(self, act):
        """
        Perform an action. Will automatically start a new episode if isOver==True.

        Args:
            act: the action

        Returns:
            tuple: (reward, isOver)
        """

    def restart_episode(self):
        """ Start a new episode, even if the current one hasn't ended. """
        raise NotImplementedError()

    def finish_episode(self):
        """ Called when an episode finishes. """
        pass

    def get_action_space(self):
        """
        Returns:
            :class:`ActionSpace`
        """
        raise NotImplementedError()

    def reset_stat(self):
        """ Reset all statistics counters. """
        self.stats = defaultdict(list)
    def play_one_episode(self, func, stat='score'):
        """ Play one episode for evaluation.

        Args:
            func: the policy function. Takes a state and returns an action.
            stat: a key or list of keys in stats to return.

        Returns:
            the stat(s) after running this episode
        """
        if not isinstance(stat, list):
            stat = [stat]
        while True:
            s = self.current_state()
            act = func(s)
            r, isOver = self.action(act)
            if isOver:
                s = [self.stats[k] for k in stat]
                self.reset_stat()
                return s if len(s) > 1 else s[0]
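# Editor's usage sketch (assumes `env` is a concrete RLEnvironment such as
# the GymEnv defined in gymenv.py): evaluate a uniform-random policy for one
# episode and read back the accumulated 'score' statistics.
#
#   spc = env.get_action_space()
#   scores = env.play_one_episode(lambda s: spc.sample(), stat='score')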
@six.add_metaclass(ABCMeta)
class ActionSpace(object):
    def __init__(self):
        self.rng = get_rng(self)

    @abstractmethod
    def sample(self):
        pass

    def num_actions(self):
        raise NotImplementedError()


class DiscreteActionSpace(ActionSpace):
    def __init__(self, num):
        super(DiscreteActionSpace, self).__init__()
        self.num = num

    def sample(self):
        # Uniformly sample an integer action in [0, num).
        return self.rng.randint(self.num)

    def num_actions(self):
        return self.num

    def __repr__(self):
        return "DiscreteActionSpace({})".format(self.num)

    def __str__(self):
        return "DiscreteActionSpace({})".format(self.num)
class NaiveRLEnvironment(RLEnvironment):
    """ For testing only. """

    def __init__(self):
        self.k = 0

    def current_state(self):
        self.k += 1
        return self.k

    def action(self, act):
        self.k = act
        return (self.k, self.k > 10)
class ProxyPlayer(RLEnvironment):
    """ Serve as a proxy to another player. """

    def __init__(self, player):
        self.player = player

    def reset_stat(self):
        self.player.reset_stat()

    def current_state(self):
        return self.player.current_state()

    def action(self, act):
        return self.player.action(act)

    @property
    def stats(self):
        return self.player.stats

    def restart_episode(self):
        self.player.restart_episode()

    def finish_episode(self):
        self.player.finish_episode()

    def get_action_space(self):
        return self.player.get_action_space()
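As a hedged illustration (editor's sketch, not part of this file): proxies nest, each layer forwarding to `.player` unless it overrides a method, so a typical preprocessing stack composes like gym wrappers. The environment name and limit below are illustrative:

player = GymEnv('Breakout-v0')              # defined in gymenv.py below
player = HistoryFramePlayer(player, 4)      # defined in history.py below
player = LimitLengthPlayer(player, 60000)   # defined in common.py above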
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: gymenv.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import time
import threading

from ..utils.fs import mkdir_p
from ..utils.stats import StatCounter
from .envbase import RLEnvironment, DiscreteActionSpace

__all__ = ['GymEnv']

_ENV_LOCK = threading.Lock()
class GymEnv(RLEnvironment):
    """
    A wrapper for an OpenAI gym environment. Can optionally auto-restart.
    Only supports discrete action spaces for now.
    """

    def __init__(self, name, dumpdir=None, viz=False, auto_restart=True):
        """
        Args:
            name (str): the gym environment name.
            dumpdir (str): the directory to dump recordings to.
            viz (float): if nonzero, render the environment and sleep this
                many seconds after each observation.
            auto_restart (bool): whether to restart after the episode ends.
        """
        with _ENV_LOCK:
            self.gymenv = gym.make(name)
        if dumpdir:
            mkdir_p(dumpdir)
            self.gymenv = gym.wrappers.Monitor(self.gymenv, dumpdir)
        self.use_dir = dumpdir

        self.reset_stat()
        self.rwd_counter = StatCounter()
        self.restart_episode()
        self.auto_restart = auto_restart
        self.viz = viz

    def restart_episode(self):
        self.rwd_counter.reset()
        self._ob = self.gymenv.reset()

    def finish_episode(self):
        self.stats['score'].append(self.rwd_counter.sum)

    def current_state(self):
        if self.viz:
            self.gymenv.render()
            time.sleep(self.viz)
        return self._ob

    def action(self, act):
        self._ob, r, isOver, info = self.gymenv.step(act)
        self.rwd_counter.feed(r)
        if isOver:
            self.finish_episode()
            if self.auto_restart:
                self.restart_episode()
        return r, isOver

    def get_action_space(self):
        spc = self.gymenv.action_space
        assert isinstance(spc, gym.spaces.discrete.Discrete)
        return DiscreteActionSpace(spc.n)
try:
    import gym
    import gym.wrappers
    # TODO
    # gym.undo_logger_setup()
    # https://github.com/openai/gym/pull/199
    # not sure whether it causes other problems
except ImportError:
    from ..utils.develop import create_dummy_class
    GymEnv = create_dummy_class('GymEnv', 'gym')    # noqa
if __name__ == '__main__':
    import gym_ple, cv2  # noqa
    import os
    os.environ["SDL_VIDEODRIVER"] = "dummy"

    env = GymEnv('FlappyBird-v0', viz=0.1)
    num = env.get_action_space().num_actions()

    from ..utils.utils import get_rng
    rng = get_rng(num)
    while True:
        act = rng.choice(range(num))
        r, o = env.action(act)
        state = env.current_state()
        state = cv2.resize(state[:450], (84, 84))
        cv2.imshow("aa", state)
        cv2.waitKey(3)
        print(state.shape)
        if r != 0 or o:
            print(r, o)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: history.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import numpy as np
from collections import deque
from six.moves import range

from .envbase import ProxyPlayer

__all__ = ['HistoryFramePlayer']
class HistoryBuffer(object):
    def __init__(self, hist_len, concat_axis=2):
        self.buf = deque(maxlen=hist_len)
        self.concat_axis = concat_axis

    def push(self, s):
        self.buf.append(s)

    def clear(self):
        self.buf.clear()

    def get(self):
        # If the buffer isn't full yet, pad the front with all-zero
        # ("black") frames before concatenating.
        difflen = self.buf.maxlen - len(self.buf)
        if difflen == 0:
            ret = self.buf
        else:
            ret = [np.zeros_like(self.buf[0]) for _ in range(difflen)]
            ret.extend(self.buf)
        return np.concatenate(ret, axis=self.concat_axis)

    def __len__(self):
        return len(self.buf)

    @property
    def maxlen(self):
        return self.buf.maxlen
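# Editor's sketch of the padding behavior: with maxlen=4 and a single
# (84, 84, 1) frame pushed, get() prepends three all-zero frames and
# concatenates along axis 2, yielding shape (84, 84, 4).
#
#   buf = HistoryBuffer(4, concat_axis=2)
#   buf.push(np.ones((84, 84, 1)))
#   assert buf.get().shape == (84, 84, 4)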
class HistoryFramePlayer(ProxyPlayer):
    """ Include history frames in the state, padding with black images
    when there isn't enough history yet.
    It assumes the underlying player will do auto-restart.
    Maps the original frames into (H, W, HIST x channels), oldest frames first.
    """

    def __init__(self, player, hist_len):
        """
        Args:
            hist_len (int): total length of the state, including the current
                frame and `hist_len - 1` history frames.
        """
        super(HistoryFramePlayer, self).__init__(player)
        self.history = HistoryBuffer(hist_len, 2)

        s = self.player.current_state()
        self.history.push(s)

    def current_state(self):
        assert len(self.history) != 0
        return self.history.get()

    def action(self, act):
        r, isOver = self.player.action(act)
        s = self.player.current_state()
        self.history.push(s)

        if isOver:  # s would be the first frame of a new episode
            self.history.clear()
            self.history.push(s)
        return (r, isOver)

    def restart_episode(self):
        super(HistoryFramePlayer, self).restart_episode()
        self.history.clear()
        self.history.push(self.player.current_state())
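To close with a hedged end-to-end sketch (an editor's illustration; assumes gym and the named environment are available):

env = GymEnv('Breakout-v0')
env = HistoryFramePlayer(env, hist_len=4)   # stack the last 4 frames on axis 2
spc = env.get_action_space()
score = env.play_one_episode(lambda s: spc.sample())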