Commit 7e963996 authored by Yuxin Wu

Deprecate tensorpack.RL and use gym for RL examples

parent c270a1ed
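In short, the examples drop the tensorpack.RL player protocol (`current_state()` / `action(act)` returning `(reward, isOver)`) in favor of the standard gym protocol (`reset()` / `step(act)` returning `(state, reward, isOver, info)`). A minimal sketch of the new interaction loop, assuming gym with the Atari extras installed (illustrative, not code from this commit):

```python
import gym

env = gym.make('Breakout-v0')            # same env name used in the examples below
state = env.reset()
isOver = False
while not isOver:
    act = env.action_space.sample()      # stand-in for the learned policy
    state, reward, isOver, info = env.step(act)
env.close()
```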
@@ -8,6 +8,8 @@ so you won't need to look at here very often.
Here is a list of things that were changed, starting from an early version.
TensorFlow itself also changed APIs before 1.0 and those are not listed here.
+ [2017/10/10](https://github.com/ppwwyyxx/tensorpack/commit/7d40e049691d92018f50dc7d45bba5e8b140becc).
`tfutils.distributions` was deprecated in favor of `tf.distributions` introduced in TF 1.3.
+ [2017/08/02](https://github.com/ppwwyyxx/tensorpack/commit/875f4d7dbb5675f54eae5675fa3a0948309a8465).
`Trainer.get_predictor` now takes GPU id. And `Trainer.get_predictors` was deprecated.
+ 2017/06/07. Now the library explicitly depends on msgpack-numpy>=0.3.9. The serialization protocol
......
@@ -78,18 +78,21 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
s2c_socket = context.socket(zmq.DEALER)
s2c_socket.setsockopt(zmq.IDENTITY, self.identity)
# s2c_socket.set_hwm(5)
s2c_socket.connect(self.s2c)
state = player.current_state()
state = player.reset()
reward, isOver = 0, False
while True:
# after taking the last action, get to this state and get this reward/isOver.
# If isOver, get to the next-episode state immediately.
# This tuple is not the same as the one put into the memory buffer
c2s_socket.send(dumps(
(self.identity, state, reward, isOver)),
copy=False)
action = loads(s2c_socket.recv(copy=False).bytes)
reward, isOver = player.action(action)
state = player.current_state()
state, reward, isOver, _ = player.step(action)
if isOver:
state = player.reset()
# compatibility
@@ -180,17 +183,16 @@ class SimulatorMaster(threading.Thread):
if __name__ == '__main__':
import random
from tensorpack.RL import NaiveRLEnvironment
import gym
class NaiveSimulator(SimulatorProcess):
def _build_player(self):
return NaiveRLEnvironment()
return gym.make('Breakout-v0')
class NaiveActioner(SimulatorMaster):
def _get_action(self, state):
time.sleep(1)
return random.randint(1, 12)
return random.randint(1, 3)
def _on_episode_over(self, client):
# print("Over: ", client.memory)
......
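The simulator worker in this file identifies itself to the master via `zmq.IDENTITY`, so the master can route each chosen action back to the simulator that sent the state. A self-contained sketch of that identity-based DEALER/ROUTER round trip (socket roles and the `inproc` address are illustrative, not tensorpack's actual endpoints):

```python
import zmq

ctx = zmq.Context()

# master side: a ROUTER socket receives (identity, payload) frames per message
router = ctx.socket(zmq.ROUTER)
router.bind("inproc://demo")

# worker side: a DEALER socket with an explicit identity, as in the simulator
worker = ctx.socket(zmq.DEALER)
worker.setsockopt(zmq.IDENTITY, b"simulator-0")
worker.connect("inproc://demo")

worker.send(b"(state, reward, isOver)")        # worker reports its transition
ident, msg = router.recv_multipart()           # ident == b"simulator-0"
router.send_multipart([ident, b"action"])      # master routes the action back
print(worker.recv())                           # b"action"
```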
@@ -27,11 +27,12 @@ from tensorpack.tfutils.gradproc import MapGradient, SummaryGradient
from tensorpack.utils.gpu import get_nr_gpu
from tensorpack.RL import *
import gym
from simulator import *
import common
from common import (play_model, Evaluator, eval_model_multithread,
play_one_episode, play_n_episodes)
from common import (Evaluator, eval_model_multithread,
play_one_episode, play_n_episodes,
WarpFrame, FrameStack, FireResetEnv, LimitLength)
if six.PY3:
from concurrent import futures
@@ -58,15 +59,16 @@ NUM_ACTIONS = None
ENV_NAME = None
def get_player(viz=False, train=False, dumpdir=None):
pl = GymEnv(ENV_NAME, viz=viz, dumpdir=dumpdir)
pl = MapPlayerState(pl, lambda img: cv2.resize(img, IMAGE_SIZE[::-1]))
pl = HistoryFramePlayer(pl, FRAME_HISTORY)
if not train:
pl = PreventStuckPlayer(pl, 30, 1)
else:
pl = LimitLengthPlayer(pl, 60000)
return pl
def get_player(train=False, dumpdir=None):
env = gym.make(ENV_NAME)
if dumpdir:
env = gym.wrappers.Monitor(env, dumpdir)
env = FireResetEnv(env)
env = WarpFrame(env, IMAGE_SIZE)
env = FrameStack(env, 4)
if train:
env = LimitLength(env, 60000)
return env
class MySimulatorWorker(SimulatorProcess):
@@ -272,7 +274,7 @@ if __name__ == '__main__':
ENV_NAME = args.env
logger.info("Environment Name: {}".format(ENV_NAME))
NUM_ACTIONS = get_player().get_action_space().num_actions()
NUM_ACTIONS = get_player().action_space.n
logger.info("Number of actions: {}".format(NUM_ACTIONS))
if args.gpu:
@@ -280,20 +282,21 @@ if __name__ == '__main__':
if args.task != 'train':
assert args.load is not None
cfg = PredictConfig(
pred = OfflinePredictor(PredictConfig(
model=Model(),
session_init=get_model_loader(args.load),
input_names=['state'],
output_names=['policy'])
output_names=['policy']))
if args.task == 'play':
play_model(cfg, get_player(viz=0.01))
play_n_episodes(get_player(train=False), pred,
args.episode, render=True)
elif args.task == 'eval':
eval_model_multithread(cfg, args.episode, get_player)
eval_model_multithread(pred, args.episode, get_player)
elif args.task == 'gen_submit':
play_n_episodes(
get_player(train=False, dumpdir=args.output),
OfflinePredictor(cfg), args.episode)
# gym.upload(output, api_key='xxx')
pred, args.episode)
# gym.upload(args.output, api_key='xxx')
else:
dirname = os.path.join('train_log', 'train-atari-{}'.format(ENV_NAME))
logger.set_logger_dir(dirname)
......
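The `play`/`eval`/`gen_submit` tasks now share one `OfflinePredictor` built up front instead of passing a `PredictConfig` around. As the call `func([[s]])[0][0]` in common.py (further down in this commit) suggests, the predictor is used as a callable from a batch of inputs to a list of output arrays; a stand-in with the same calling convention, with dummy values and an assumed observation shape:

```python
import numpy as np

def fake_policy_predictor(inputs):
    """Mimic the state -> policy calling convention: one input (a batch of
    stacked frames), one output (a batch of action distributions)."""
    batch = np.asarray(inputs[0])
    return [np.random.rand(len(batch), 4)]        # 4 actions, dummy values

s = np.zeros((84, 84, 12), dtype='uint8')         # assumes 84x84 RGB frames stacked 4 deep
act = fake_policy_predictor([[s]])[0][0].argmax()
print(act)
```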
@@ -18,14 +18,14 @@ from collections import deque
from tensorpack import *
from tensorpack.utils.concurrency import *
from tensorpack.RL import *
import tensorflow as tf
from DQNModel import Model as DQNModel
import common
from common import play_model, Evaluator, eval_model_multithread
from atari import AtariPlayer
from common import Evaluator, eval_model_multithread, play_n_episodes
from common import FrameStack, WarpFrame, FireResetEnv
from expreplay import ExpReplay
from atari import AtariPlayer
BATCH_SIZE = 64
IMAGE_SIZE = (84, 84)
@@ -37,7 +37,7 @@ GAMMA = 0.99
MEMORY_SIZE = 1e6
# will consume at least 1e6 * 84 * 84 bytes == 6.6G memory.
INIT_MEMORY_SIZE = 5e4
INIT_MEMORY_SIZE = MEMORY_SIZE // 20
STEPS_PER_EPOCH = 10000 // UPDATE_FREQ * 10 # each epoch is 100k played frames
EVAL_EPISODE = 50
@@ -47,17 +47,14 @@ METHOD = None
def get_player(viz=False, train=False):
pl = AtariPlayer(ROM_FILE, frame_skip=ACTION_REPEAT,
image_shape=IMAGE_SIZE[::-1], viz=viz, live_lost_as_eoe=train)
env = AtariPlayer(ROM_FILE, frame_skip=ACTION_REPEAT, viz=viz,
live_lost_as_eoe=train, max_num_frames=30000)
env = FireResetEnv(env)
env = WarpFrame(env, IMAGE_SIZE)
if not train:
# create a new axis to stack history on
pl = MapPlayerState(pl, lambda im: im[:, :, np.newaxis])
# in training, history is taken care of in expreplay buffer
pl = HistoryFramePlayer(pl, FRAME_HISTORY)
pl = PreventStuckPlayer(pl, 30, 1)
pl = LimitLengthPlayer(pl, 30000)
return pl
env = FrameStack(env, FRAME_HISTORY)
return env
class Model(DQNModel):
@@ -149,20 +146,20 @@ if __name__ == '__main__':
ROM_FILE = args.rom
METHOD = args.algo
# set num_actions
NUM_ACTIONS = AtariPlayer(ROM_FILE).get_action_space().num_actions()
NUM_ACTIONS = AtariPlayer(ROM_FILE).action_space.n
logger.info("ROM: {}, Num Actions: {}".format(ROM_FILE, NUM_ACTIONS))
if args.task != 'train':
assert args.load is not None
cfg = PredictConfig(
pred = OfflinePredictor(PredictConfig(
model=Model(),
session_init=get_model_loader(args.load),
input_names=['state'],
output_names=['Qvalue'])
output_names=['Qvalue']))
if args.task == 'play':
play_model(cfg, get_player(viz=0.01))
play_n_episodes(get_player(viz=0.01), pred, 100)
elif args.task == 'eval':
eval_model_multithread(cfg, EVAL_EPISODE, get_player)
eval_model_multithread(pred, EVAL_EPISODE, get_player)
else:
logger.set_logger_dir(
os.path.join('train_log', 'DQN-{}'.format(
......
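Since `AtariPlayer` now emits single-channel `(h, w)` frames and leaves resizing/stacking to the wrappers, the evaluation-time chain in `get_player` ends up with the usual 84x84x4 DQN input. A sketch of that chain (the ROM path is a placeholder; requires the ALE bindings):

```python
from atari import AtariPlayer
from common import FireResetEnv, WarpFrame, FrameStack

env = AtariPlayer('breakout.bin', frame_skip=4, viz=0)   # placeholder ROM path
env = FireResetEnv(env)
env = WarpFrame(env, (84, 84))        # grayscale (210, 160) -> (84, 84)
env = FrameStack(env, 4)              # stack history on a new last axis

print(env.observation_space.shape)    # (84, 84, 4)
ob = env.reset()
print(ob.shape, ob.dtype)             # (84, 84, 4) uint8
```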
@@ -7,7 +7,6 @@ import numpy as np
import time
import os
import cv2
from collections import deque
import threading
import six
from six.moves import range
@@ -16,7 +15,9 @@ from tensorpack.utils.utils import get_rng, execute_only_once
from tensorpack.utils.fs import get_dataset_path
from tensorpack.utils.stats import StatCounter
from tensorpack.RL.envbase import RLEnvironment, DiscreteActionSpace
import gym
from gym import spaces
from gym.envs.atari.atari_env import ACTION_MEANING
from ale_python_interface import ALEInterface
@@ -26,27 +27,29 @@ ROM_URL = "https://github.com/openai/atari-py/tree/master/atari_py/atari_roms"
_ALE_LOCK = threading.Lock()
class AtariPlayer(RLEnvironment):
class AtariPlayer(gym.Env):
"""
A wrapper for atari emulator.
Will automatically restart when a real episode ends (isOver might be just
a loss of lives, not game over).
A wrapper for ALE emulator, with configurations to mimic DeepMind DQN settings.
Info:
score: the accumulated reward in the current game
gameOver: True when the current game is over
"""
def __init__(self, rom_file, viz=0, height_range=(None, None),
frame_skip=4, image_shape=(84, 84), nullop_start=30,
live_lost_as_eoe=True):
def __init__(self, rom_file, viz=0,
frame_skip=4, nullop_start=30,
live_lost_as_eoe=True, max_num_frames=0):
"""
:param rom_file: path to the rom
:param frame_skip: skip every k frames and repeat the action
:param image_shape: (w, h)
:param height_range: (h1, h2) to cut
:param viz: visualization to be done.
Set to 0 to disable.
Set to a positive number to be the delay between frames to show.
Set to a string to be a directory to store frames.
:param nullop_start: start with random number of null ops
:param live_losts_as_eoe: consider lost of lives as end of episode. useful for training.
Args:
rom_file: path to the rom
frame_skip: skip every k frames and repeat the action
viz: visualization to be done.
Set to 0 to disable.
Set to a positive number to be the delay between frames to show.
Set to a string to be a directory to store frames.
nullop_start: start with random number of null ops.
live_lost_as_eoe: consider a loss of lives as end of episode. Useful for training.
max_num_frames: maximum number of frames per episode.
"""
super(AtariPlayer, self).__init__()
if not os.path.isfile(rom_file) and '/' not in rom_file:
@@ -65,6 +68,7 @@ class AtariPlayer(RLEnvironment):
self.ale = ALEInterface()
self.rng = get_rng(self)
self.ale.setInt(b"random_seed", self.rng.randint(0, 30000))
self.ale.setInt(b"max_num_frames_per_episode", max_num_frames)
self.ale.setBool(b"showinfo", False)
self.ale.setInt(b"frame_skip", 1)
@@ -92,11 +96,16 @@ class AtariPlayer(RLEnvironment):
self.live_lost_as_eoe = live_lost_as_eoe
self.frame_skip = frame_skip
self.nullop_start = nullop_start
self.height_range = height_range
self.image_shape = image_shape
self.current_episode_score = StatCounter()
self.restart_episode()
self.action_space = spaces.Discrete(len(self.actions))
self.observation_space = spaces.Box(
low=0, high=255, shape=(self.height, self.width))
self._restart_episode()
def get_action_meanings(self):
return [ACTION_MEANING[i] for i in self.actions]
def _grab_raw_image(self):
"""
@@ -105,7 +114,7 @@ class AtariPlayer(RLEnvironment):
m = self.ale.getScreenRGB()
return m.reshape((self.height, self.width, 3))
def current_state(self):
def _current_state(self):
"""
:returns: a gray-scale (h, w) uint8 image
"""
@@ -116,19 +125,12 @@ class AtariPlayer(RLEnvironment):
if isinstance(self.viz, float):
cv2.imshow(self.windowname, ret)
time.sleep(self.viz)
ret = ret[self.height_range[0]:self.height_range[1], :].astype('float32')
ret = ret.astype('float32')
# 0.299, 0.587, 0.114; same as rgb2y in torch/image
ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
ret = cv2.resize(ret, self.image_shape)
return ret.astype('uint8') # to save some memory
def get_action_space(self):
return DiscreteActionSpace(len(self.actions))
def finish_episode(self):
self.stats['score'].append(self.current_episode_score.sum)
def restart_episode(self):
def _restart_episode(self):
self.current_episode_score.reset()
with _ALE_LOCK:
self.ale.reset_game()
@@ -141,11 +143,12 @@ class AtariPlayer(RLEnvironment):
self.last_raw_screen = self._grab_raw_image()
self.ale.act(0)
def action(self, act):
"""
:param act: an index of the action
:returns: (reward, isOver)
"""
def _reset(self):
if self.ale.game_over():
self._restart_episode()
return self._current_state()
def _step(self, act):
oldlives = self.ale.lives()
r = 0
for k in range(self.frame_skip):
@@ -158,55 +161,24 @@ class AtariPlayer(RLEnvironment):
break
self.current_episode_score.feed(r)
isOver = self.ale.game_over()
trueIsOver = isOver = self.ale.game_over()
if self.live_lost_as_eoe:
isOver = isOver or newlives < oldlives
if isOver:
self.finish_episode()
if self.ale.game_over():
self.restart_episode()
return (r, isOver)
info = {'score': self.current_episode_score.sum, 'gameOver': trueIsOver}
return self._current_state(), r, isOver, info
if __name__ == '__main__':
import sys
def benchmark():
a = AtariPlayer(sys.argv[1], viz=False, height_range=(28, -8))
num = a.get_action_space().num_actions()
rng = get_rng(num)
start = time.time()
cnt = 0
while True:
act = rng.choice(range(num))
r, o = a.action(act)
a.current_state()
cnt += 1
if cnt == 5000:
break
print(time.time() - start)
if len(sys.argv) == 3 and sys.argv[2] == 'benchmark':
import threading
import multiprocessing
for k in range(3):
# th = multiprocessing.Process(target=benchmark)
th = threading.Thread(target=benchmark)
th.start()
time.sleep(0.02)
benchmark()
else:
a = AtariPlayer(sys.argv[1],
viz=0.03, height_range=(28, -8))
num = a.get_action_space().num_actions()
rng = get_rng(num)
import time
while True:
# im = a.grab_image()
# cv2.imshow(a.romname, im)
act = rng.choice(range(num))
print(act)
r, o = a.action(act)
a.current_state()
# time.sleep(0.1)
print(r, o)
a = AtariPlayer(sys.argv[1], viz=0.03)
num = a.action_space.n
rng = get_rng(num)
while True:
act = rng.choice(range(num))
state, reward, isOver, info = a.step(act)
if isOver:
print(info)
a.reset()
print("Reward:", reward)
@@ -7,35 +7,56 @@ import time
import threading
import multiprocessing
import numpy as np
import cv2
from collections import deque
from tqdm import tqdm
from six.moves import queue
from tensorpack import *
from tensorpack.utils.concurrency import *
from tensorpack.utils.stats import *
import gym
from gym import spaces
from tensorpack.utils.concurrency import StoppableThread, ShareSessionThread
from tensorpack.callbacks import Triggerable
from tensorpack.utils import logger
from tensorpack.utils.stats import StatCounter
from tensorpack.utils.utils import get_tqdm_kwargs
def play_one_episode(player, func, verbose=False):
def f(s):
spc = player.get_action_space()
def play_one_episode(env, func, render=False):
def predict(s):
"""
Map from observation to action, with 0.001 epsilon-greedy exploration.
"""
act = func([[s]])[0][0].argmax()
if random.random() < 0.001:
spc = env.action_space
act = spc.sample()
if verbose:
print(act)
return act
return np.mean(player.play_one_episode(f))
def play_model(cfg, player):
predfunc = OfflinePredictor(cfg)
ob = env.reset()
sum_r = 0
while True:
score = play_one_episode(player, predfunc)
print("Total:", score)
act = predict(ob)
ob, r, isOver, info = env.step(act)
if render:
env.render()
sum_r += r
if isOver:
return sum_r
def play_n_episodes(player, predfunc, nr, render=False):
logger.info("Start Playing ... ")
for k in range(nr):
score = play_one_episode(player, predfunc, render=render)
print("{}/{}, score={}".format(k, nr, score))
def eval_with_funcs(predictors, nr_eval, get_player_fn):
"""
Args:
predictors ([PredictorBase])
"""
class Worker(StoppableThread, ShareSessionThread):
def __init__(self, func, queue):
super(Worker, self).__init__()
@@ -85,10 +106,14 @@ def eval_with_funcs(predictors, nr_eval, get_player_fn):
return (0, 0)
def eval_model_multithread(cfg, nr_eval, get_player_fn):
func = OfflinePredictor(cfg)
def eval_model_multithread(pred, nr_eval, get_player_fn):
"""
Args:
pred (OfflinePredictor): state -> Qvalue
"""
NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
mean, max = eval_with_funcs([func] * NR_PROC, nr_eval, get_player_fn)
with pred.sess.as_default():
mean, max = eval_with_funcs([pred] * NR_PROC, nr_eval, get_player_fn)
logger.info("Average Score: {}; Max Score: {}".format(mean, max))
@@ -115,10 +140,103 @@ class Evaluator(Triggerable):
self.trainer.monitors.put_scalar('max_score', max)
def play_n_episodes(player, predfunc, nr):
logger.info("Start evaluation: ")
for k in range(nr):
if k != 0:
player.restart_episode()
score = play_one_episode(player, predfunc)
print("{}/{}, score={}".format(k, nr, score))
"""
------------------------------------------------------------------------------
The following wrappers are copied or modified from openai/baselines:
https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
"""
class WarpFrame(gym.ObservationWrapper):
def __init__(self, env, shape):
gym.ObservationWrapper.__init__(self, env)
self.shape = shape
obs = env.observation_space
assert isinstance(obs, spaces.Box)
chan = 1 if len(obs.shape) == 2 else obs.shape[2]
shape3d = shape if chan == 1 else shape + (chan,)
self.observation_space = spaces.Box(low=0, high=255, shape=shape3d)
def _observation(self, obs):
return cv2.resize(obs, self.shape)
class FrameStack(gym.Wrapper):
def __init__(self, env, k):
"""Buffer observations and stack across channels (last axis)."""
gym.Wrapper.__init__(self, env)
self.k = k
self.frames = deque([], maxlen=k)
shp = env.observation_space.shape
chan = 1 if len(shp) == 2 else shp[2]
self._base_dim = len(shp)
self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], chan * k))
def _reset(self):
"""Clear buffer and re-fill by duplicating the first observation."""
ob = self.env.reset()
for _ in range(self.k - 1):
self.frames.append(np.zeros_like(ob))
self.frames.append(ob)
return self._observation()
def _step(self, action):
ob, reward, done, info = self.env.step(action)
self.frames.append(ob)
return self._observation(), reward, done, info
def _observation(self):
assert len(self.frames) == self.k
if self._base_dim == 2:
return np.stack(self.frames, axis=-1)
else:
return np.concatenate(self.frames, axis=2)
class _FireResetEnv(gym.Wrapper):
def __init__(self, env):
"""Take action on reset for environments that are fixed until firing."""
gym.Wrapper.__init__(self, env)
assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
assert len(env.unwrapped.get_action_meanings()) >= 3
def _reset(self):
self.env.reset()
obs, _, done, _ = self.env.step(1)
if done:
self.env.reset()
obs, _, done, _ = self.env.step(2)
if done:
self.env.reset()
return obs
def FireResetEnv(env):
if isinstance(env, gym.Wrapper):
baseenv = env.unwrapped
else:
baseenv = env
if 'FIRE' in baseenv.get_action_meanings():
return _FireResetEnv(env)
return env
class LimitLength(gym.Wrapper):
def __init__(self, env, k):
gym.Wrapper.__init__(self, env)
self.k = k
def _reset(self):
# This assumes that reset() will really reset the env.
# If the underlying env tries to be smart about reset
# (e.g. end-of-life), the assumption doesn't hold.
ob = self.env.reset()
self.cnt = 0
return ob
def _step(self, action):
ob, r, done, info = self.env.step(action)
self.cnt += 1
if self.cnt == self.k:
done = True
return ob, r, done, info
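For reference, `FrameStack._reset` above pads the buffer with k-1 zero frames before the first real observation, so the very first stacked state is mostly zeros. A tiny numpy illustration of what it returns for a 2-D base observation:

```python
import numpy as np
from collections import deque

k = 4
ob = np.full((84, 84), 7, dtype='uint8')     # pretend first observation
frames = deque(maxlen=k)
for _ in range(k - 1):
    frames.append(np.zeros_like(ob))         # zero padding, as in _reset()
frames.append(ob)

stacked = np.stack(frames, axis=-1)          # 2-D base -> stack on a new last axis
print(stacked.shape)                         # (84, 84, 4)
print(stacked[..., :3].max(), stacked[..., 3].max())   # 0 7
```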
@@ -13,6 +13,7 @@ from six.moves import queue, range
from tensorpack.dataflow import DataFlow
from tensorpack.utils import logger
from tensorpack.utils.utils import get_tqdm, get_rng
from tensorpack.utils.stats import StatCounter
from tensorpack.utils.concurrency import LoopThread, ShareSessionThread
from tensorpack.callbacks.base import Callback
@@ -142,7 +143,7 @@ class ExpReplay(DataFlow, Callback):
if k != 'self':
setattr(self, k, v)
self.exploration = init_exploration
self.num_actions = player.get_action_space().num_actions()
self.num_actions = player.action_space.n
logger.info("Number of Legal actions: {}".format(self.num_actions))
self.rng = get_rng(self)
@@ -152,6 +153,8 @@ class ExpReplay(DataFlow, Callback):
self._populate_job_queue = queue.Queue(maxsize=5)
self.mem = ReplayMemory(memory_size, state_shape, history_len)
self._current_ob = self.player.reset()
self._player_scores = StatCounter()
def get_simulator_thread(self):
# spawn a separate thread to run policy
@@ -186,7 +189,7 @@ class ExpReplay(DataFlow, Callback):
def _populate_exp(self):
""" populate a transition by epsilon-greedy"""
old_s = self.player.current_state()
old_s = self._current_ob
if self.rng.rand() <= self.exploration or (len(self.mem) <= self.history_len):
act = self.rng.choice(range(self.num_actions))
else:
@@ -198,7 +201,11 @@ class ExpReplay(DataFlow, Callback):
# assume batched network
q_values = self.predictor([[history]])[0][0] # this is the bottleneck
act = np.argmax(q_values)
reward, isOver = self.player.action(act)
self._current_ob, reward, isOver, info = self.player.step(act)
if isOver:
if info['gameOver']: # only record score when a whole game is over (not when an episode is over)
self._player_scores.feed(info['score'])
self.player.reset()
self.mem.append(Experience(old_s, act, reward, isOver))
def _debug_sample(self, sample):
@@ -245,17 +252,15 @@ class ExpReplay(DataFlow, Callback):
self._simulator_th = self.get_simulator_thread()
self._simulator_th.start()
def _trigger_epoch(self):
# log player statistics in training
stats = self.player.stats
for k, v in six.iteritems(stats):
try:
mean, max = np.mean(v), np.max(v)
self.trainer.monitors.put_scalar('expreplay/mean_' + k, mean)
self.trainer.monitors.put_scalar('expreplay/max_' + k, max)
except:
logger.exception("Cannot log training scores.")
self.player.reset_stat()
def _trigger(self):
v = self._player_scores
try:
mean, max = v.average, v.max
self.trainer.monitors.put_scalar('expreplay/mean_score', mean)
self.trainer.monitors.put_scalar('expreplay/max_score', max)
except:
logger.exception("Cannot log training scores.")
v.reset()
if __name__ == '__main__':
......
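The exploration logic in `_populate_exp` above is plain epsilon-greedy over the predicted Q-values. In isolation, with dummy values:

```python
import numpy as np

rng = np.random.RandomState(0)
num_actions, exploration = 4, 0.1          # exploration plays the role of epsilon
q_values = rng.rand(num_actions)           # stand-in for predictor([[history]])[0][0]

if rng.rand() <= exploration:
    act = rng.choice(range(num_actions))   # explore: uniform random action
else:
    act = int(np.argmax(q_values))         # exploit: greedy w.r.t. Q-values
print(act)
```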
@@ -3,6 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from pkgutil import iter_modules
from ..utils.develop import log_deprecated
import os
import os.path
@@ -13,6 +14,8 @@ __all__ = []
This module should be removed in the future.
"""
log_deprecated("tensorpack.RL", "Please use gym or other APIs instead!", "2017-12-31")
def _global_import(name):
p = __import__(name, globals(), locals(), level=1)
......