Commit 4c7348c3 authored by Yuxin Wu

change how expreplay works...

parent 0c5e39eb
@@ -3,8 +3,8 @@
# File: DQN.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
import numpy as np
import tensorflow as tf
import os, sys, re
import random
import argparse
@@ -22,7 +22,7 @@ from tensorpack.predict import PredictConfig, get_predict_func, ParallelPredictW
from tensorpack.tfutils import symbolic_functions as symbf
from tensorpack.callbacks import *
from tensorpack.dataflow.dataset import AtariDriver, AtariPlayer
from tensorpack.dataflow.dataset import AtariPlayer
from tensorpack.dataflow.RL import ExpReplay
"""
@@ -36,13 +36,13 @@ IMAGE_SIZE = 84
NUM_ACTIONS = None
FRAME_HISTORY = 4
ACTION_REPEAT = 3
HEIGHT_RANGE = (36, 204) # for breakout
# HEIGHT_RANGE = (28, -8) # for pong
#HEIGHT_RANGE = (36, 204) # for breakout
HEIGHT_RANGE = (28, -8) # for pong
GAMMA = 0.99
BATCH_SIZE = 32
INIT_EXPLORATION = 1
EXPLORATION_EPOCH_ANNEAL = 0.0025
EXPLORATION_EPOCH_ANNEAL = 0.0020
END_EXPLORATION = 0.1
MEMORY_SIZE = 1e6
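A quick check on the new schedule: annealing exploration linearly by EXPLORATION_EPOCH_ANNEAL per epoch, it now takes (1 - 0.1) / 0.0020 = 450 epochs to decay from INIT_EXPLORATION = 1 to END_EXPLORATION = 0.1, up from (1 - 0.1) / 0.0025 = 360 epochs before.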
@@ -62,15 +62,20 @@ class Model(ModelDesc):
def _get_DQN_prediction(self, image, is_training):
""" image: [0,255]"""
image = image / 128.0 - 1
with argscope(Conv2D, nl=tf.nn.relu, use_bias=True):
l = Conv2D('conv0', image, out_channel=32, kernel_shape=5, stride=2)
l = Conv2D('conv1', l, out_channel=32, kernel_shape=5, stride=2)
l = Conv2D('conv2', l, out_channel=64, kernel_shape=4, stride=2)
image = image / 255.0
with argscope(Conv2D, nl=PReLU.f, use_bias=True):
l = Conv2D('conv0', image, out_channel=32, kernel_shape=5, stride=1)
l = MaxPooling('pool0', l, 2)
l = Conv2D('conv1', l, out_channel=32, kernel_shape=5, stride=1)
l = MaxPooling('pool1', l, 2)
l = Conv2D('conv2', l, out_channel=64, kernel_shape=4)
l = MaxPooling('pool2', l, 2)
l = Conv2D('conv3', l, out_channel=64, kernel_shape=3)
l = MaxPooling('pool3', l, 2)
l = Conv2D('conv4', l, out_channel=64, kernel_shape=3)
l = FullyConnected('fc0', l, 512)
l = FullyConnected('fct', l, out_dim=NUM_ACTIONS, nl=tf.identity, summary_activation=False)
l = FullyConnected('fc0', l, 512, nl=lambda x, name: LeakyReLU.f(x, 0.01, name))
l = FullyConnected('fct', l, out_dim=NUM_ACTIONS, nl=tf.identity)
return l
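For orientation, a rough shape trace of the new tower (a sketch only: it assumes SAME-padded convolutions and unpadded 2x2 pooling, tensorpack's defaults at the time, plus FRAME_HISTORY = 4 input channels):

# input: 84 x 84 x 4  (stacked gray frames)
# conv0 5x5/1 -> 84x84x32 ; pool0 -> 42x42x32
# conv1 5x5/1 -> 42x42x32 ; pool1 -> 21x21x32
# conv2 4x4   -> 21x21x64 ; pool2 -> 10x10x64
# conv3 3x3   -> 10x10x64 ; pool3 -> 5x5x64
# conv4 3x3   -> 5x5x64   -> fc0 (512, leaky relu) -> fct (NUM_ACTIONS, identity)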
def _build_graph(self, inputs, is_training):
@@ -136,14 +141,14 @@ def play_one_episode(player, func, verbose=False):
tot_reward = 0
que = deque(maxlen=30)
while True:
s = player.current_state()
s = player.current_state() # XXX
outputs = func([[s]])
action_value = outputs[0][0]
act = action_value.argmax()
if verbose:
print action_value, act
if random.random() < 0.01:
act = random.choice(range(player.driver.get_num_actions()))
act = random.choice(range(NUM_ACTIONS))
if len(que) == que.maxlen \
and que.count(que[0]) == que.maxlen:
act = 1 # hack, avoid stuck
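A toy check (not in the commit) of the stuck heuristic above: once the deque of the last 30 actions is full and every entry equals the first, the count test fires and action 1 (FIRE in most ALE minimal action sets) is forced:

from collections import deque
que = deque([3] * 30, maxlen=30)
print(len(que) == que.maxlen and que.count(que[0]) == que.maxlen)  # True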
@@ -156,10 +161,11 @@ def play_one_episode(player, func, verbose=False):
return tot_reward
def play_model(model_path, romfile):
player = AtariPlayer(AtariDriver(romfile, viz=0.01, height_range=HEIGHT_RANGE),
action_repeat=ACTION_REPEAT)
player = HistoryFramePlayer(AtariPlayer(
romfile, viz=0.01, height_range=HEIGHT_RANGE,
frame_skip=ACTION_REPEAT), FRAME_HISTORY)
global NUM_ACTIONS
NUM_ACTIONS = player.driver.get_num_actions()
NUM_ACTIONS = player.player.get_num_actions()
M = Model()
cfg = PredictConfig(
@@ -186,10 +192,11 @@ def eval_model_multiprocess(model_path, romfile):
self.outq = outqueue
def run(self):
player = AtariPlayer(AtariDriver(romfile, viz=0, height_range=HEIGHT_RANGE),
action_repeat=ACTION_REPEAT)
player = HistoryFramePlayer(AtariPlayer(
romfile, viz=0, height_range=HEIGHT_RANGE,
frame_skip=ACTION_REPEAT), FRAME_HISTORY)
global NUM_ACTIONS
NUM_ACTIONS = player.driver.get_num_actions()
NUM_ACTIONS = player.player.get_num_actions()
self._init_runtime()
while True:
score = play_one_episode(player, self.func)
@@ -226,15 +233,15 @@ def get_config(romfile):
os.path.join('train_log', basename[:basename.rfind('.')]))
M = Model()
driver = AtariDriver(romfile, height_range=HEIGHT_RANGE)
player = AtariPlayer(
romfile, height_range=HEIGHT_RANGE,
frame_skip=ACTION_REPEAT)
global NUM_ACTIONS
NUM_ACTIONS = driver.get_num_actions()
NUM_ACTIONS = player.get_num_actions()
dataset_train = ExpReplay(
predictor=current_predictor,
player=AtariPlayer(
driver, hist_len=FRAME_HISTORY,
action_repeat=ACTION_REPEAT),
player=player,
num_actions=NUM_ACTIONS,
memory_size=MEMORY_SIZE,
batch_size=BATCH_SIZE,
@@ -242,22 +249,23 @@ def get_config(romfile):
exploration=INIT_EXPLORATION,
end_exploration=END_EXPLORATION,
exploration_epoch_anneal=EXPLORATION_EPOCH_ANNEAL,
reward_clip=(-1, 2))
reward_clip=(-1, 1),
history_len=FRAME_HISTORY)
lr = tf.Variable(0.0025, trainable=False, name='learning_rate')
lr = tf.Variable(0.00025, trainable=False, name='learning_rate')
tf.scalar_summary('learning_rate', lr)
class Evaluator(Callback):
def _trigger_epoch(self):
logger.info("Evaluating...")
output = subprocess.check_output(
"""{} --task eval --rom {} --load {} 2>&1 | grep Average""".format(
"""CUDA_VISIBLE_DEVICES= {} --task eval --rom {} --load {} 2>&1 | grep Average""".format(
sys.argv[0], romfile, os.path.join(logger.LOG_DIR, 'checkpoint')), shell=True)
output = output.strip()
output = output[output.find(']')+1:]
mean, maximum = re.findall('[0-9\.]+', output)
self.trainer.write_scalar_summary('eval_mean_score', mean)
self.trainer.write_scalar_summary('eval_max_score', maximum)
mean, maximum = re.findall('[0-9\.\-]+', output)[-2:]
self.trainer.write_scalar_summary('mean_score', mean)
self.trainer.write_scalar_summary('max_score', maximum)
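A toy check of the widened regex (the log line below is assumed, not taken from the commit): allowing '-' copes with negative scores such as pong's, and keeping only the last two matches tolerates extra numbers earlier in the line:

import re
output = "[0102 12:00:00] Average Score: -15.2; Max: 3.0"
output = output[output.find(']')+1:]
mean, maximum = re.findall('[0-9\.\-]+', output)[-2:]
print(mean, maximum)   # -15.2 3.0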
return TrainConfig(
dataset=dataset_train,
@@ -269,7 +277,7 @@ def get_config(romfile):
HumanHyperParamSetter((dataset_train, 'exploration'), 'hyper.txt'),
TargetNetworkUpdator(M),
dataset_train,
PeriodicCallback(Evaluator(), 1),
PeriodicCallback(Evaluator(), 2),
]),
session_config=get_default_sess_config(0.5),
model=M,
......
@@ -19,10 +19,10 @@ from tensorpack.callbacks.base import Callback
Implement RL-related data preprocessing
"""
__all__ = ['ExpReplay', 'RLEnvironment', 'NaiveRLEnvironment']
__all__ = ['ExpReplay', 'RLEnvironment', 'NaiveRLEnvironment', 'HistoryFramePlayer']
Experience = namedtuple('Experience',
['state', 'action', 'reward', 'next', 'isOver'])
['state', 'action', 'reward', 'isOver'])
class RLEnvironment(object):
__meta__ = ABCMeta
@@ -65,6 +65,49 @@ class NaiveRLEnvironment(RLEnvironment):
self.k = act
return (self.k, self.k > 10)
class ProxyPlayer(RLEnvironment):
def __init__(self, player):
self.player = player
def get_stat(self):
return self.player.get_stat()
def reset_stat(self):
self.player.reset_stat()
def current_state(self):
return self.player.current_state()
def action(self, act):
return self.player.action(act)
class HistoryFramePlayer(ProxyPlayer):
def __init__(self, player, hist_len):
super(HistoryFramePlayer, self).__init__(player)
self.history = deque(maxlen=hist_len)
s = self.player.current_state()
self.history.append(s)
def current_state(self):
assert len(self.history) != 0
diff_len = self.history.maxlen - len(self.history)
if diff_len == 0:
return np.concatenate(self.history, axis=2)
zeros = [np.zeros_like(self.history[0]) for k in range(diff_len)]
for k in self.history:
zeros.append(k)
return np.concatenate(zeros, axis=2)
def action(self, act):
r, isOver = self.player.action(act)
s = self.player.current_state()
self.history.append(s)
if isOver: # s would be a new episode
self.history.clear()
self.history.append(s)
return (r, isOver)
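A minimal check (not part of the commit) of the zero-fill in current_state(): with maxlen=4 but only two observed frames, zeros are prepended so the newest frame always sits in the last channel:

import numpy as np
from collections import deque
hist = deque(maxlen=4)
hist.append(np.ones((84, 84, 1)))
hist.append(np.ones((84, 84, 1)) * 2)
pad = [np.zeros_like(hist[0]) for _ in range(hist.maxlen - len(hist))]
state = np.concatenate(pad + list(hist), axis=2)
print(state.shape)   # (84, 84, 4)
print(state[0, 0])   # [0. 0. 1. 2.]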
class ExpReplay(DataFlow, Callback):
"""
@@ -82,11 +125,15 @@ class ExpReplay(DataFlow, Callback):
end_exploration=0.1,
exploration_epoch_anneal=0.002,
reward_clip=None,
new_experience_per_step=1
new_experience_per_step=1,
history_len=1
):
"""
:param predictor: callable. called with a state, returns a distribution
:param predictor: a callable calling the up-to-date network.
called with a state, returns a distribution
:param player: a `RLEnvironment`
:param num_actions: int
:param history_len: number of history frames to concatenate; initial frames are zero-filled
"""
for k, v in locals().items():
if k != 'self':
@@ -106,51 +153,83 @@ class ExpReplay(DataFlow, Callback):
raise RuntimeError("Don't run me in multiple processes")
def _populate_exp(self):
p = self.rng.rand()
old_s = self.player.current_state()
if p <= self.exploration:
if self.rng.rand() <= self.exploration:
act = self.rng.choice(range(self.num_actions))
else:
act = np.argmax(self.predictor(old_s)) # TODO race condition in session?
# build a history state
ss = [old_s]
for k in range(1, self.history_len):
hist_exp = self.mem[-k]
if hist_exp.isOver:
ss.append(np.zeros_like(ss[0]))
else:
ss.append(hist_exp.state)
ss = np.concatenate(ss, axis=2)
act = np.argmax(self.predictor(ss))
reward, isOver = self.player.action(act)
if self.reward_clip:
reward = np.clip(reward, self.reward_clip[0], self.reward_clip[1])
s = self.player.current_state()
#def view_state(state):
#""" for debug state representation"""
#r = np.concatenate([state[:,:,k] for k in range(state.shape[2])], axis=1)
#print r.shape
#cv2.imshow("state", r)
#cv2.waitKey()
#print act, reward
#view_state(s)
# s is considered useless if isOver==True
self.mem.append(Experience(old_s, act, reward, s, isOver))
self.mem.append(Experience(old_s, act, reward, isOver))
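The action choice above is plain epsilon-greedy; as a standalone sketch (the names here are illustrative, not from the commit):

import numpy as np
rng = np.random.RandomState(0)

def eps_greedy(q_values, exploration):
    # explore with probability `exploration`, otherwise act greedily
    if rng.rand() <= exploration:
        return rng.choice(len(q_values))
    return int(np.argmax(q_values))

print(eps_greedy(np.array([0.1, 0.9, 0.3]), 0.0))   # 1 (greedy)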
def get_data(self):
# new s is considered useless if isOver==True
while True:
idxs = self.rng.randint(len(self.mem), size=self.batch_size)
batch_exp = [self.mem[k] for k in idxs]
batch_exp = [self.sample_one() for _ in range(self.batch_size)]
def view_state(state, next_state):
""" for debug state representation"""
r = np.concatenate([state[:,:,k] for k in range(self.history_len)], axis=1)
r2 = np.concatenate([next_state[:,:,k] for k in range(self.history_len)], axis=1)
print r.shape
r = np.concatenate([r, r2], axis=0)
cv2.imshow("state", r)
cv2.waitKey()
exp = batch_exp[0]
print("Act: ", exp[3], " reward:", exp[2], " isOver: ", exp[4])
view_state(exp[0], exp[1])
yield self._process_batch(batch_exp)
for _ in range(self.new_experience_per_step):
self._populate_exp()
def sample_one(self):
""" return one transition tuple: the stacked frames
[idx, idx+history_len) -> [idx+1, idx+1+history_len),
i.e. the transition taken at state idx+history_len-1
"""
# pick an index to start from;
# when mem[x].isOver == True, mem[x+1].state belongs to a different episode
idx = self.rng.randint(len(self.mem) - self.history_len - 1)
start_idx = idx + self.history_len - 1
def concat(idx):
v = [self.mem[x].state for x in range(idx, idx+self.history_len)]
return np.concatenate(v, axis=2)
state = concat(idx)
next_state = concat(idx + 1)
reward = self.mem[start_idx].reward
action = self.mem[start_idx].action
isOver = self.mem[start_idx].isOver
# zero-fill state before starting
zero_fill = False
for k in range(1, self.history_len):
if self.mem[start_idx-k].isOver:
zero_fill = True
if zero_fill:
state[:,:,-k-1] = 0
if k + 2 <= self.history_len:
next_state[:,:,-k-2] = 0
return (state, next_state, reward, action, isOver)
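A toy rendition of a sample that crosses an episode boundary (this is not the commit's exact loop: the sticky zero_fill flag is collapsed into a break, which is equivalent because every frame at or before the first isOver gets zeroed):

import numpy as np

history_len = 4
# one entry per frame: a 1x1x1 "image" holding its index, plus isOver; episode ends at frame 5
mem = [(np.full((1, 1, 1), float(i)), i == 5) for i in range(10)]

def concat(idx):
    return np.concatenate([mem[x][0] for x in range(idx, idx + history_len)], axis=2)

idx = 3
start_idx = idx + history_len - 1          # 6
state, next_state = concat(idx), concat(idx + 1)
for k in range(1, history_len):
    if mem[start_idx - k][1]:              # frames <= start_idx-k belong to the old episode
        state[:, :, :history_len - k] = 0
        next_state[:, :, :history_len - k - 1] = 0
        break
print(state.ravel(), next_state.ravel())   # [0. 0. 0. 6.] [0. 0. 6. 7.]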
def _process_batch(self, batch_exp):
state_shape = batch_exp[0].state.shape
state = np.zeros((self.batch_size, ) + state_shape, dtype='float32')
next_state = np.zeros((self.batch_size, ) + state_shape, dtype='float32')
reward = np.zeros((self.batch_size,), dtype='float32')
action = np.zeros((self.batch_size,), dtype='int32')
isOver = np.zeros((self.batch_size,), dtype='bool')
for idx, b in enumerate(batch_exp):
state[idx] = b.state
action[idx] = b.action
next_state[idx] = b.next
reward[idx] = b.reward
isOver[idx] = b.isOver
state = np.array([e[0] for e in batch_exp])
next_state = np.array([e[1] for e in batch_exp])
reward = np.array([e[2] for e in batch_exp])
action = np.array([e[3] for e in batch_exp])
isOver = np.array([e[4] for e in batch_exp])
return [state, action, reward, next_state, isOver]
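np.array over a list of equal-shape items stacks along a new leading batch axis, which is what lets the five comprehensions above replace the old preallocation loop; note the dtypes now follow the stored data instead of being forced to float32/int32. A sketch:

import numpy as np
batch = [(np.zeros((84, 84, 4)), np.zeros((84, 84, 4)), 1.0, 2, False)] * 32
state = np.array([e[0] for e in batch])
print(state.shape)   # (32, 84, 84, 4)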
# Callback-related:
@@ -170,12 +249,16 @@ class ExpReplay(DataFlow, Callback):
if __name__ == '__main__':
from tensorpack.dataflow.dataset import AtariDriver, AtariPlayer
from tensorpack.dataflow.dataset import AtariPlayer
import sys
predictor = lambda x: np.array([1,1,1,1])
predictor.initialized = False
E = AtariExpReplay(predictor, predictor,
AtariPlayer(AtariDriver('../../space_invaders.bin', viz=0.01)),
populate_size=1000)
player = AtariPlayer(sys.argv[1], viz=0, frame_skip=20)
E = ExpReplay(predictor,
player=player,
num_actions=player.get_num_actions(),
populate_size=1001,
history_len=4)
E.init_memory()
for k in E.get_data():
......
@@ -9,6 +9,7 @@ import os
import cv2
from collections import deque
from ...utils import get_rng, logger
from ...utils.stat import StatCounter
from ..RL import RLEnvironment
try:
@@ -16,23 +17,27 @@ try:
except ImportError:
logger.warn("Cannot import ale_python_interface, Atari won't be available.")
__all__ = ['AtariDriver', 'AtariPlayer']
__all__ = ['AtariPlayer']
class AtariDriver(RLEnvironment):
class AtariPlayer(RLEnvironment):
"""
A wrapper for atari emulator.
"""
def __init__(self, rom_file, viz=0, height_range=(None,None)):
def __init__(self, rom_file, viz=0, height_range=(None,None),
frame_skip=4, image_shape=(84, 84)):
"""
:param rom_file: path to the rom
:param frame_skip: skip every k frames
:param image_shape: (w, h)
:param height_range: (h1, h2) to cut
:param viz: the delay. visualize the game while running. 0 to disable
"""
super(AtariPlayer, self).__init__()
self.ale = ALEInterface()
self.rng = get_rng(self)
self.ale.setInt("random_seed", self.rng.randint(self.rng.randint(0, 1000)))
self.ale.setInt("frame_skip", 1)
self.ale.setInt("frame_skip", frame_skip)
self.ale.setBool('color_averaging', True)
self.ale.loadROM(rom_file)
self.width, self.height = self.ale.getScreenDims()
@@ -45,9 +50,11 @@ class AtariDriver(RLEnvironment):
if self.viz and isinstance(self.viz, float):
cv2.startWindowThread()
cv2.namedWindow(self.romname)
self.framenum = 0
self.height_range = height_range
self.framenum = 0
self.image_shape = image_shape
self.current_episode_score = StatCounter()
self._reset()
@@ -61,9 +68,9 @@ class AtariDriver(RLEnvironment):
def current_state(self):
"""
:returns: a gray-scale image, max-pooled over the last frame.
:returns: a gray-scale (h, w, 1) image
"""
now = self._grab_raw_image()
ret = self._grab_raw_image()
if self.viz:
if isinstance(self.viz, float):
cv2.imshow(self.romname, ret)
@@ -73,6 +80,8 @@ class AtariDriver(RLEnvironment):
self.framenum += 1
ret = ret[self.height_range[0]:self.height_range[1],:]
ret = cv2.cvtColor(ret, cv2.COLOR_BGR2YUV)[:,:,0]
ret = cv2.resize(ret, self.image_shape)
ret = np.expand_dims(ret, axis=2)
return ret
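As a standalone sketch of the per-frame preprocessing above (the raw screen size is assumed; ALE screens vary by game):

import cv2
import numpy as np
frame = np.random.randint(0, 256, (210, 160, 3), dtype=np.uint8)   # fake raw screen
height_range = (28, -8)                        # the pong crop used above
ret = frame[height_range[0]:height_range[1], :]
ret = cv2.cvtColor(ret, cv2.COLOR_BGR2YUV)[:, :, 0]   # keep the luma channel
ret = cv2.resize(ret, (84, 84))                # cv2.resize takes (w, h)
ret = np.expand_dims(ret, axis=2)
print(ret.shape)                               # (84, 84, 1)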
def get_num_actions(self):
@@ -82,6 +91,7 @@ class AtariDriver(RLEnvironment):
return len(self.actions)
def _reset(self):
self.current_episode_score.reset()
self.ale.reset_game()
def action(self, act):
@@ -90,80 +100,13 @@ class AtariDriver(RLEnvironment):
:returns: (reward, isOver)
"""
r = self.ale.act(self.actions[act])
self.current_episode_score.feed(r)
isOver = self.ale.game_over()
if isOver:
self.stats['score'].append(self.current_episode_score.sum)
self._reset()
return (r, isOver)
class AtariPlayer(RLEnvironment):
""" An Atari game player with limited memory and FPS"""
def __init__(self, driver, hist_len=4, action_repeat=4, image_shape=(84,84)):
"""
:param driver: an `AtariDriver` instance.
:param hist_len: history(memory) length
:param action_repeat: repeat each action `action_repeat` times and skip those frames
:param image_shape: the shape of the observed image
"""
super(AtariPlayer, self).__init__()
for k, v in locals().items():
if k != 'self':
setattr(self, k, v)
self.last_act = 0
self.frames = deque(maxlen=hist_len)
self.current_accum_score = 0
self.restart()
def restart(self):
"""
Restart the game and populate frames with the beginning frame
"""
self.current_accum_score = 0
self.frames.clear()
s = self.driver.current_state()
s = cv2.resize(s, self.image_shape)
for _ in range(self.hist_len):
self.frames.append(s)
def current_state(self):
"""
Return a current state of shape `image_shape + (hist_len,)`
"""
return self._build_state()
def action(self, act):
"""
Perform an action
:param act: index of the action
:returns: (reward, isOver)
"""
self.last_act = act
return self._observe()
def _build_state(self):
assert len(self.frames) == self.hist_len
m = np.array(self.frames)
m = m.transpose([1,2,0])
return m
def _observe(self):
""" if isOver==True, current_state will return the new episode
"""
totr = 0
for k in range(self.action_repeat):
r, isOver = self.driver.action(self.last_act)
s = self.driver.current_state()
totr += r
if isOver:
break
s = cv2.resize(s, self.image_shape)
self.current_accum_score += totr
self.frames.append(s)
if isOver:
self.stats['score'].append(self.current_accum_score)
self.restart()
return (totr, isOver)
def get_stat(self):
try:
return {'avg_score': np.mean(self.stats['score']),
@@ -173,7 +116,8 @@ class AtariPlayer(RLEnvironment):
if __name__ == '__main__':
import sys
a = AtariDriver(sys.argv[1], viz=0.01, height_range=(28,-8))
a = AtariPlayer(sys.argv[1],
viz=0.01, height_range=(28,-8))
num = a.get_num_actions()
rng = get_rng(num)
import time
@@ -183,6 +127,7 @@ if __name__ == '__main__':
act = rng.choice(range(num))
print act
r, o = a.action(act)
a.current_state()
#time.sleep(0.1)
print(r, o)