Commit 7e963996 authored by Yuxin Wu

deprecated tensorpack.RL and use gym for RL examples

parent c270a1ed
@@ -8,6 +8,8 @@ so you won't need to look at here very often.
 Here are a list of things that were changed, starting from an early version.
 TensorFlow itself also changed APIs before 1.0 and those are not listed here.
++ [2017/10/10](https://github.com/ppwwyyxx/tensorpack/commit/7d40e049691d92018f50dc7d45bba5e8b140becc).
+  `tfutils.distributions` was deprecated in favor of `tf.distributions` introduced in TF 1.3.
 + [2017/08/02](https://github.com/ppwwyyxx/tensorpack/commit/875f4d7dbb5675f54eae5675fa3a0948309a8465).
   `Trainer.get_predictor` now takes GPU id. And `Trainer.get_predictors` was deprecated.
 + 2017/06/07. Now the library explicitly depends on msgpack-numpy>=0.3.9. The serialization protocol
......
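In short, the examples move from the deprecated `tensorpack.RL` player interface to the standard gym interface. A minimal before/after sketch of that interface change (illustrative only, not taken from the commit; 'Breakout-v0' is simply the environment used in the example code below):

# Old tensorpack.RL player interface (deprecated by this commit):
#     state = player.current_state()
#     reward, isOver = player.action(act)
#     player.restart_episode()
#     num_actions = player.get_action_space().num_actions()
#
# New gym-style interface used throughout the examples:
import gym

env = gym.make('Breakout-v0')
num_actions = env.action_space.n       # replaces get_action_space().num_actions()
state = env.reset()                    # replaces current_state() / restart_episode()
done = False
while not done:
    act = env.action_space.sample()    # stand-in for the learned policy
    state, reward, done, info = env.step(act)   # 4-tuple replaces (reward, isOver)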
@@ -78,18 +78,21 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
         s2c_socket = context.socket(zmq.DEALER)
         s2c_socket.setsockopt(zmq.IDENTITY, self.identity)
-        # s2c_socket.set_hwm(5)
         s2c_socket.connect(self.s2c)
-        state = player.current_state()
+        state = player.reset()
         reward, isOver = 0, False
         while True:
+            # after taking the last action, get to this state and get this reward/isOver.
+            # If isOver, get to the next-episode state immediately.
+            # This tuple is not the same as the one put into the memory buffer
            c2s_socket.send(dumps(
                (self.identity, state, reward, isOver)),
                copy=False)
            action = loads(s2c_socket.recv(copy=False).bytes)
-            reward, isOver = player.action(action)
-            state = player.current_state()
+            state, reward, isOver, _ = player.step(action)
+            if isOver:
+                state = player.reset()
 # compatibility
@@ -180,17 +183,16 @@ class SimulatorMaster(threading.Thread):
 if __name__ == '__main__':
     import random
-    from tensorpack.RL import NaiveRLEnvironment
+    import gym
     class NaiveSimulator(SimulatorProcess):
         def _build_player(self):
-            return NaiveRLEnvironment()
+            return gym.make('Breakout-v0')
     class NaiveActioner(SimulatorMaster):
         def _get_action(self, state):
            time.sleep(1)
-            return random.randint(1, 12)
+            return random.randint(1, 3)
         def _on_episode_over(self, client):
            # print("Over: ", client.memory)
......
@@ -27,11 +27,12 @@ from tensorpack.tfutils.gradproc import MapGradient, SummaryGradient
 from tensorpack.utils.gpu import get_nr_gpu
-from tensorpack.RL import *
+import gym
 from simulator import *
 import common
-from common import (play_model, Evaluator, eval_model_multithread,
-                    play_one_episode, play_n_episodes)
+from common import (Evaluator, eval_model_multithread,
+                    play_one_episode, play_n_episodes,
+                    WarpFrame, FrameStack, FireResetEnv, LimitLength)
 if six.PY3:
     from concurrent import futures
@@ -58,15 +59,16 @@ NUM_ACTIONS = None
 ENV_NAME = None
-def get_player(viz=False, train=False, dumpdir=None):
-    pl = GymEnv(ENV_NAME, viz=viz, dumpdir=dumpdir)
-    pl = MapPlayerState(pl, lambda img: cv2.resize(img, IMAGE_SIZE[::-1]))
-    pl = HistoryFramePlayer(pl, FRAME_HISTORY)
-    if not train:
-        pl = PreventStuckPlayer(pl, 30, 1)
-    else:
-        pl = LimitLengthPlayer(pl, 60000)
-    return pl
+def get_player(train=False, dumpdir=None):
+    env = gym.make(ENV_NAME)
+    if dumpdir:
+        env = gym.wrappers.Monitor(env, dumpdir)
+    env = FireResetEnv(env)
+    env = WarpFrame(env, IMAGE_SIZE)
+    env = FrameStack(env, 4)
+    if train:
+        env = LimitLength(env, 60000)
+    return env
 class MySimulatorWorker(SimulatorProcess):
@@ -272,7 +274,7 @@ if __name__ == '__main__':
     ENV_NAME = args.env
     logger.info("Environment Name: {}".format(ENV_NAME))
-    NUM_ACTIONS = get_player().get_action_space().num_actions()
+    NUM_ACTIONS = get_player().action_space.n
     logger.info("Number of actions: {}".format(NUM_ACTIONS))
     if args.gpu:
@@ -280,20 +282,21 @@ if __name__ == '__main__':
     if args.task != 'train':
         assert args.load is not None
-        cfg = PredictConfig(
+        pred = OfflinePredictor(PredictConfig(
             model=Model(),
             session_init=get_model_loader(args.load),
             input_names=['state'],
-            output_names=['policy'])
+            output_names=['policy']))
         if args.task == 'play':
-            play_model(cfg, get_player(viz=0.01))
+            play_n_episodes(get_player(train=False), pred,
+                            args.episode, render=True)
         elif args.task == 'eval':
-            eval_model_multithread(cfg, args.episode, get_player)
+            eval_model_multithread(pred, args.episode, get_player)
         elif args.task == 'gen_submit':
             play_n_episodes(
                 get_player(train=False, dumpdir=args.output),
-                OfflinePredictor(cfg), args.episode)
-            # gym.upload(output, api_key='xxx')
+                pred, args.episode)
+            # gym.upload(args.output, api_key='xxx')
     else:
         dirname = os.path.join('train_log', 'train-atari-{}'.format(ENV_NAME))
         logger.set_logger_dir(dirname)
......
@@ -18,14 +18,14 @@ from collections import deque
 from tensorpack import *
 from tensorpack.utils.concurrency import *
-from tensorpack.RL import *
 import tensorflow as tf
 from DQNModel import Model as DQNModel
 import common
-from common import play_model, Evaluator, eval_model_multithread
-from atari import AtariPlayer
+from common import Evaluator, eval_model_multithread, play_n_episodes
+from common import FrameStack, WarpFrame, FireResetEnv
 from expreplay import ExpReplay
+from atari import AtariPlayer
 BATCH_SIZE = 64
 IMAGE_SIZE = (84, 84)
@@ -37,7 +37,7 @@ GAMMA = 0.99
 MEMORY_SIZE = 1e6
 # will consume at least 1e6 * 84 * 84 bytes == 6.6G memory.
-INIT_MEMORY_SIZE = 5e4
+INIT_MEMORY_SIZE = MEMORY_SIZE // 20
 STEPS_PER_EPOCH = 10000 // UPDATE_FREQ * 10  # each epoch is 100k played frames
 EVAL_EPISODE = 50
@@ -47,17 +47,14 @@ METHOD = None
 def get_player(viz=False, train=False):
-    pl = AtariPlayer(ROM_FILE, frame_skip=ACTION_REPEAT,
-                     image_shape=IMAGE_SIZE[::-1], viz=viz, live_lost_as_eoe=train)
+    env = AtariPlayer(ROM_FILE, frame_skip=ACTION_REPEAT, viz=viz,
+                      live_lost_as_eoe=train, max_num_frames=30000)
+    env = FireResetEnv(env)
+    env = WarpFrame(env, IMAGE_SIZE)
     if not train:
-        # create a new axis to stack history on
-        pl = MapPlayerState(pl, lambda im: im[:, :, np.newaxis])
         # in training, history is taken care of in expreplay buffer
-        pl = HistoryFramePlayer(pl, FRAME_HISTORY)
-        pl = PreventStuckPlayer(pl, 30, 1)
-    pl = LimitLengthPlayer(pl, 30000)
-    return pl
+        env = FrameStack(env, FRAME_HISTORY)
+    return env
 class Model(DQNModel):
@@ -149,20 +146,20 @@ if __name__ == '__main__':
     ROM_FILE = args.rom
     METHOD = args.algo
     # set num_actions
-    NUM_ACTIONS = AtariPlayer(ROM_FILE).get_action_space().num_actions()
+    NUM_ACTIONS = AtariPlayer(ROM_FILE).action_space.n
     logger.info("ROM: {}, Num Actions: {}".format(ROM_FILE, NUM_ACTIONS))
     if args.task != 'train':
         assert args.load is not None
-        cfg = PredictConfig(
+        pred = OfflinePredictor(PredictConfig(
             model=Model(),
             session_init=get_model_loader(args.load),
             input_names=['state'],
-            output_names=['Qvalue'])
+            output_names=['Qvalue']))
        if args.task == 'play':
-            play_model(cfg, get_player(viz=0.01))
+            play_n_episodes(get_player(viz=0.01), pred, 100)
        elif args.task == 'eval':
-            eval_model_multithread(cfg, EVAL_EPISODE, get_player)
+            eval_model_multithread(pred, EVAL_EPISODE, get_player)
     else:
        logger.set_logger_dir(
            os.path.join('train_log', 'DQN-{}'.format(
......
@@ -7,7 +7,6 @@ import numpy as np
 import time
 import os
 import cv2
-from collections import deque
 import threading
 import six
 from six.moves import range
@@ -16,7 +15,9 @@ from tensorpack.utils.utils import get_rng, execute_only_once
 from tensorpack.utils.fs import get_dataset_path
 from tensorpack.utils.stats import StatCounter
-from tensorpack.RL.envbase import RLEnvironment, DiscreteActionSpace
+import gym
+from gym import spaces
+from gym.envs.atari.atari_env import ACTION_MEANING
 from ale_python_interface import ALEInterface
@@ -26,27 +27,29 @@ ROM_URL = "https://github.com/openai/atari-py/tree/master/atari_py/atari_roms"
 _ALE_LOCK = threading.Lock()
-class AtariPlayer(RLEnvironment):
+class AtariPlayer(gym.Env):
     """
-    A wrapper for atari emulator.
-    Will automatically restart when a real episode ends (isOver might be just
-    lost of lives but not game over).
+    A wrapper for ALE emulator, with configurations to mimic DeepMind DQN settings.
+    Info:
+        score: the accumulated reward in the current game
+        gameOver: True when the current game is Over
     """
-    def __init__(self, rom_file, viz=0, height_range=(None, None),
-                 frame_skip=4, image_shape=(84, 84), nullop_start=30,
-                 live_lost_as_eoe=True):
+    def __init__(self, rom_file, viz=0,
+                 frame_skip=4, nullop_start=30,
+                 live_lost_as_eoe=True, max_num_frames=0):
         """
-        :param rom_file: path to the rom
-        :param frame_skip: skip every k frames and repeat the action
-        :param image_shape: (w, h)
-        :param height_range: (h1, h2) to cut
-        :param viz: visualization to be done.
+        Args:
+            rom_file: path to the rom
+            frame_skip: skip every k frames and repeat the action
+            viz: visualization to be done.
                 Set to 0 to disable.
                 Set to a positive number to be the delay between frames to show.
                 Set to a string to be a directory to store frames.
-        :param nullop_start: start with random number of null ops
-        :param live_losts_as_eoe: consider lost of lives as end of episode. useful for training.
+            nullop_start: start with random number of null ops.
+            live_losts_as_eoe: consider lost of lives as end of episode. Useful for training.
+            max_num_frames: maximum number of frames per episode.
         """
         super(AtariPlayer, self).__init__()
         if not os.path.isfile(rom_file) and '/' not in rom_file:
@@ -65,6 +68,7 @@ class AtariPlayer(RLEnvironment):
         self.ale = ALEInterface()
         self.rng = get_rng(self)
         self.ale.setInt(b"random_seed", self.rng.randint(0, 30000))
+        self.ale.setInt(b"max_num_frames_per_episode", max_num_frames)
         self.ale.setBool(b"showinfo", False)
         self.ale.setInt(b"frame_skip", 1)
@@ -92,11 +96,16 @@ class AtariPlayer(RLEnvironment):
         self.live_lost_as_eoe = live_lost_as_eoe
         self.frame_skip = frame_skip
         self.nullop_start = nullop_start
-        self.height_range = height_range
-        self.image_shape = image_shape
         self.current_episode_score = StatCounter()
-        self.restart_episode()
+        self.action_space = spaces.Discrete(len(self.actions))
+        self.observation_space = spaces.Box(
+            low=0, high=255, shape=(self.height, self.width))
+        self._restart_episode()
+    def get_action_meanings(self):
+        return [ACTION_MEANING[i] for i in self.actions]
     def _grab_raw_image(self):
         """
@@ -105,7 +114,7 @@ class AtariPlayer(RLEnvironment):
         m = self.ale.getScreenRGB()
         return m.reshape((self.height, self.width, 3))
-    def current_state(self):
+    def _current_state(self):
         """
         :returns: a gray-scale (h, w) uint8 image
         """
@@ -116,19 +125,12 @@ class AtariPlayer(RLEnvironment):
            if isinstance(self.viz, float):
                cv2.imshow(self.windowname, ret)
                time.sleep(self.viz)
-        ret = ret[self.height_range[0]:self.height_range[1], :].astype('float32')
+        ret = ret.astype('float32')
         # 0.299,0.587.0.114. same as rgb2y in torch/image
         ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
-        ret = cv2.resize(ret, self.image_shape)
         return ret.astype('uint8')  # to save some memory
-    def get_action_space(self):
-        return DiscreteActionSpace(len(self.actions))
-    def finish_episode(self):
-        self.stats['score'].append(self.current_episode_score.sum)
-    def restart_episode(self):
+    def _restart_episode(self):
         self.current_episode_score.reset()
         with _ALE_LOCK:
             self.ale.reset_game()
@@ -141,11 +143,12 @@ class AtariPlayer(RLEnvironment):
            self.last_raw_screen = self._grab_raw_image()
            self.ale.act(0)
-    def action(self, act):
-        """
-        :param act: an index of the action
-        :returns: (reward, isOver)
-        """
+    def _reset(self):
+        if self.ale.game_over():
+            self._restart_episode()
+        return self._current_state()
+    def _step(self, act):
         oldlives = self.ale.lives()
         r = 0
         for k in range(self.frame_skip):
@@ -158,55 +161,24 @@ class AtariPlayer(RLEnvironment):
                break
         self.current_episode_score.feed(r)
-        isOver = self.ale.game_over()
+        trueIsOver = isOver = self.ale.game_over()
         if self.live_lost_as_eoe:
             isOver = isOver or newlives < oldlives
-        if isOver:
-            self.finish_episode()
-            if self.ale.game_over():
-                self.restart_episode()
-        return (r, isOver)
+        info = {'score': self.current_episode_score.sum, 'gameOver': trueIsOver}
+        return self._current_state(), r, isOver, info
 if __name__ == '__main__':
     import sys
-    def benchmark():
-        a = AtariPlayer(sys.argv[1], viz=False, height_range=(28, -8))
-        num = a.get_action_space().num_actions()
-        rng = get_rng(num)
-        start = time.time()
-        cnt = 0
-        while True:
-            act = rng.choice(range(num))
-            r, o = a.action(act)
-            a.current_state()
-            cnt += 1
-            if cnt == 5000:
-                break
-        print(time.time() - start)
-    if len(sys.argv) == 3 and sys.argv[2] == 'benchmark':
-        import threading
-        import multiprocessing
-        for k in range(3):
-            # th = multiprocessing.Process(target=benchmark)
-            th = threading.Thread(target=benchmark)
-            th.start()
-            time.sleep(0.02)
-        benchmark()
-    else:
-        a = AtariPlayer(sys.argv[1],
-                        viz=0.03, height_range=(28, -8))
-        num = a.get_action_space().num_actions()
+    a = AtariPlayer(sys.argv[1], viz=0.03)
+    num = a.action_space.n
     rng = get_rng(num)
-    import time
     while True:
-        # im = a.grab_image()
-        # cv2.imshow(a.romname, im)
        act = rng.choice(range(num))
-        print(act)
-        r, o = a.action(act)
-        a.current_state()
-        # time.sleep(0.1)
-        print(r, o)
+        state, reward, isOver, info = a.step(act)
+        if isOver:
+            print(info)
+            a.reset()
+        print("Reward:", reward)
@@ -7,35 +7,56 @@ import time
 import threading
 import multiprocessing
 import numpy as np
+import cv2
+from collections import deque
 from tqdm import tqdm
 from six.moves import queue
-from tensorpack import *
-from tensorpack.utils.concurrency import *
-from tensorpack.utils.stats import *
+import gym
+from gym import spaces
+from tensorpack.utils.concurrency import StoppableThread, ShareSessionThread
+from tensorpack.callbacks import Triggerable
+from tensorpack.utils import logger
+from tensorpack.utils.stats import StatCounter
 from tensorpack.utils.utils import get_tqdm_kwargs
-def play_one_episode(player, func, verbose=False):
-    def f(s):
-        spc = player.get_action_space()
+def play_one_episode(env, func, render=False):
+    def predict(s):
+        """
+        Map from observation to action, with 0.001 greedy.
+        """
         act = func([[s]])[0][0].argmax()
         if random.random() < 0.001:
+            spc = env.action_space
             act = spc.sample()
-        if verbose:
-            print(act)
         return act
-    return np.mean(player.play_one_episode(f))
-def play_model(cfg, player):
-    predfunc = OfflinePredictor(cfg)
+    ob = env.reset()
+    sum_r = 0
     while True:
-        score = play_one_episode(player, predfunc)
-        print("Total:", score)
+        act = predict(ob)
+        ob, r, isOver, info = env.step(act)
+        if render:
+            env.render()
+        sum_r += r
+        if isOver:
+            return sum_r
+def play_n_episodes(player, predfunc, nr, render=False):
+    logger.info("Start Playing ... ")
+    for k in range(nr):
+        score = play_one_episode(player, predfunc, render=render)
+        print("{}/{}, score={}".format(k, nr, score))
 def eval_with_funcs(predictors, nr_eval, get_player_fn):
+    """
+    Args:
+        predictors ([PredictorBase])
+    """
     class Worker(StoppableThread, ShareSessionThread):
         def __init__(self, func, queue):
             super(Worker, self).__init__()
@@ -85,10 +106,14 @@ def eval_with_funcs(predictors, nr_eval, get_player_fn):
     return (0, 0)
-def eval_model_multithread(cfg, nr_eval, get_player_fn):
-    func = OfflinePredictor(cfg)
+def eval_model_multithread(pred, nr_eval, get_player_fn):
+    """
+    Args:
+        pred (OfflinePredictor): state -> Qvalue
+    """
     NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
-    mean, max = eval_with_funcs([func] * NR_PROC, nr_eval, get_player_fn)
+    with pred.sess.as_default():
+        mean, max = eval_with_funcs([pred] * NR_PROC, nr_eval, get_player_fn)
     logger.info("Average Score: {}; Max Score: {}".format(mean, max))
@@ -115,10 +140,103 @@ class Evaluator(Triggerable):
         self.trainer.monitors.put_scalar('max_score', max)
-def play_n_episodes(player, predfunc, nr):
-    logger.info("Start evaluation: ")
-    for k in range(nr):
-        if k != 0:
-            player.restart_episode()
-        score = play_one_episode(player, predfunc)
-        print("{}/{}, score={}".format(k, nr, score))
+"""
+------------------------------------------------------------------------------
+The following wrappers are copied or modified from openai/baselines:
+https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
+"""
+class WarpFrame(gym.ObservationWrapper):
+    def __init__(self, env, shape):
+        gym.ObservationWrapper.__init__(self, env)
+        self.shape = shape
+        obs = env.observation_space
+        assert isinstance(obs, spaces.Box)
+        chan = 1 if len(obs.shape) == 2 else obs.shape[2]
+        shape3d = shape if chan == 1 else shape + (chan,)
+        self.observation_space = spaces.Box(low=0, high=255, shape=shape3d)
+    def _observation(self, obs):
+        return cv2.resize(obs, self.shape)
+class FrameStack(gym.Wrapper):
+    def __init__(self, env, k):
+        """Buffer observations and stack across channels (last axis)."""
+        gym.Wrapper.__init__(self, env)
+        self.k = k
+        self.frames = deque([], maxlen=k)
+        shp = env.observation_space.shape
+        chan = 1 if len(shp) == 2 else shp[2]
+        self._base_dim = len(shp)
+        self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], chan * k))
+    def _reset(self):
+        """Clear buffer and re-fill by duplicating the first observation."""
+        ob = self.env.reset()
+        for _ in range(self.k - 1):
+            self.frames.append(np.zeros_like(ob))
+        self.frames.append(ob)
+        return self._observation()
+    def _step(self, action):
+        ob, reward, done, info = self.env.step(action)
+        self.frames.append(ob)
+        return self._observation(), reward, done, info
+    def _observation(self):
+        assert len(self.frames) == self.k
+        if self._base_dim == 2:
+            return np.stack(self.frames, axis=-1)
+        else:
+            return np.concatenate(self.frames, axis=2)
+class _FireResetEnv(gym.Wrapper):
+    def __init__(self, env):
+        """Take action on reset for environments that are fixed until firing."""
+        gym.Wrapper.__init__(self, env)
+        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
+        assert len(env.unwrapped.get_action_meanings()) >= 3
+    def _reset(self):
+        self.env.reset()
+        obs, _, done, _ = self.env.step(1)
+        if done:
+            self.env.reset()
+        obs, _, done, _ = self.env.step(2)
+        if done:
+            self.env.reset()
+        return obs
+def FireResetEnv(env):
+    if isinstance(env, gym.Wrapper):
+        baseenv = env.unwrapped
+    else:
+        baseenv = env
+    if 'FIRE' in baseenv.get_action_meanings():
+        return _FireResetEnv(env)
+    return env
+class LimitLength(gym.Wrapper):
+    def __init__(self, env, k):
+        gym.Wrapper.__init__(self, env)
+        self.k = k
+    def _reset(self):
+        # This assumes that reset() will really reset the env.
+        # If the underlying env tries to be smart about reset
+        # (e.g. end-of-life), the assumption doesn't hold.
+        ob = self.env.reset()
+        self.cnt = 0
+        return ob
+    def _step(self, action):
+        ob, r, done, info = self.env.step(action)
+        self.cnt += 1
+        if self.cnt == self.k:
+            done = True
+        return ob, r, done, info
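A note on the wrappers above: the `_reset`/`_step`/`_observation` hook names follow the gym `Wrapper` API of this era; later gym releases expect the un-prefixed `reset`/`step`/`observation` overrides instead. As a purely illustrative sketch in the same style (not part of the commit; the `ClipReward` name is hypothetical), an extra wrapper would look like:

import gym
import numpy as np

class ClipReward(gym.Wrapper):
    """Clip rewards to {-1, 0, +1}, a common DQN trick (illustrative only)."""
    def _step(self, action):
        ob, r, done, info = self.env.step(action)
        return ob, np.sign(r), done, info

It would compose the same way as `get_player` does, e.g. `ClipReward(FrameStack(WarpFrame(FireResetEnv(env), (84, 84)), 4))`.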
@@ -13,6 +13,7 @@ from six.moves import queue, range
 from tensorpack.dataflow import DataFlow
 from tensorpack.utils import logger
 from tensorpack.utils.utils import get_tqdm, get_rng
+from tensorpack.utils.stats import StatCounter
 from tensorpack.utils.concurrency import LoopThread, ShareSessionThread
 from tensorpack.callbacks.base import Callback
@@ -142,7 +143,7 @@ class ExpReplay(DataFlow, Callback):
            if k != 'self':
                setattr(self, k, v)
         self.exploration = init_exploration
-        self.num_actions = player.get_action_space().num_actions()
+        self.num_actions = player.action_space.n
         logger.info("Number of Legal actions: {}".format(self.num_actions))
         self.rng = get_rng(self)
@@ -152,6 +153,8 @@ class ExpReplay(DataFlow, Callback):
         self._populate_job_queue = queue.Queue(maxsize=5)
         self.mem = ReplayMemory(memory_size, state_shape, history_len)
+        self._current_ob = self.player.reset()
+        self._player_scores = StatCounter()
     def get_simulator_thread(self):
         # spawn a separate thread to run policy
@@ -186,7 +189,7 @@ class ExpReplay(DataFlow, Callback):
     def _populate_exp(self):
         """ populate a transition by epsilon-greedy"""
-        old_s = self.player.current_state()
+        old_s = self._current_ob
         if self.rng.rand() <= self.exploration or (len(self.mem) <= self.history_len):
             act = self.rng.choice(range(self.num_actions))
         else:
@@ -198,7 +201,11 @@ class ExpReplay(DataFlow, Callback):
             # assume batched network
             q_values = self.predictor([[history]])[0][0]  # this is the bottleneck
             act = np.argmax(q_values)
-        reward, isOver = self.player.action(act)
+        self._current_ob, reward, isOver, info = self.player.step(act)
+        if isOver:
+            if info['gameOver']:  # only record score when a whole game is over (not when an episode is over)
+                self._player_scores.feed(info['score'])
+            self.player.reset()
         self.mem.append(Experience(old_s, act, reward, isOver))
     def _debug_sample(self, sample):
@@ -245,17 +252,15 @@ class ExpReplay(DataFlow, Callback):
         self._simulator_th = self.get_simulator_thread()
         self._simulator_th.start()
-    def _trigger_epoch(self):
-        # log player statistics in training
-        stats = self.player.stats
-        for k, v in six.iteritems(stats):
+    def _trigger(self):
+        v = self._player_scores
         try:
-            mean, max = np.mean(v), np.max(v)
-            self.trainer.monitors.put_scalar('expreplay/mean_' + k, mean)
-            self.trainer.monitors.put_scalar('expreplay/max_' + k, max)
+            mean, max = v.average, v.max
+            self.trainer.monitors.put_scalar('expreplay/mean_score', mean)
+            self.trainer.monitors.put_scalar('expreplay/max_score', max)
         except:
             logger.exception("Cannot log training scores.")
-        self.player.reset_stat()
+        v.reset()
 if __name__ == '__main__':
......
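The `info` dict introduced in `AtariPlayer._step` is what the score bookkeeping above relies on: with `live_lost_as_eoe=True`, `isOver` fires on every lost life, while `info['gameOver']` is only True when the game truly ends. A hedged usage sketch (illustrative, not from the commit; the rom path is hypothetical):

from atari import AtariPlayer

env = AtariPlayer('breakout.bin', live_lost_as_eoe=True)   # hypothetical rom path
ob = env.reset()
while True:
    ob, reward, isOver, info = env.step(env.action_space.sample())
    if isOver:                      # fires on every lost life (live_lost_as_eoe=True)
        if info['gameOver']:        # the score is only final when the game truly ends
            print("final score:", info['score'])
            break
        ob = env.reset()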
@@ -3,6 +3,7 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 from pkgutil import iter_modules
+from ..utils.develop import log_deprecated
 import os
 import os.path
@@ -13,6 +14,8 @@ __all__ = []
 This module should be removed in the future.
 """
+log_deprecated("tensorpack.RL", "Please use gym or other APIs instead!", "2017-12-31")
 def _global_import(name):
     p = __import__(name, globals(), locals(), level=1)
......