Commit 49a21a29 authored by Yuxin Wu

DQN release ready

parent 33353f33
@@ -28,21 +28,15 @@ from tensorpack.callbacks import *
from tensorpack.RL import *
"""
Implement DQN in:
Human-level Control Through Deep Reinforcement Learning
for Atari games. Use the variants in:
Deep Reinforcement Learning with Double Q-learning.
"""
BATCH_SIZE = 32
IMAGE_SIZE = (84, 84)
FRAME_HISTORY = 4
ACTION_REPEAT = 3
ACTION_REPEAT = 4
HEIGHT_RANGE = (36, 204) # for breakout
#HEIGHT_RANGE = (28, -8) # for pong
CHANNEL = FRAME_HISTORY
IMAGE_SHAPE3 = IMAGE_SIZE + (CHANNEL,)
#HEIGHT_RANGE = (28, -8) # for pong
GAMMA = 0.99
INIT_EXPLORATION = 1
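The docstring above cites Double Q-learning: the online network selects the greedy next action and the target network evaluates it, which reduces the overestimation of plain Q-learning. A minimal NumPy sketch of that target (illustrative names, not code from this commit):
```
import numpy as np

def double_dqn_target(q_online_next, q_target_next, reward, done, gamma=0.99):
    """Double-DQN target (van Hasselt et al.): select with the online net,
    evaluate with the target net. gamma matches GAMMA above."""
    a_star = np.argmax(q_online_next, axis=1)               # selection: online net
    q_eval = q_target_next[np.arange(len(a_star)), a_star]  # evaluation: target net
    return reward + gamma * (1.0 - done) * q_eval           # done masks the bootstrap
```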
@@ -52,7 +46,7 @@ END_EXPLORATION = 0.1
MEMORY_SIZE = 1e6
INIT_MEMORY_SIZE = 50000
STEP_PER_EPOCH = 10000
EVAL_EPISODE = 100
EVAL_EPISODE = 50
NUM_ACTIONS = None
ROM_FILE = None
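INIT_EXPLORATION and END_EXPLORATION are the endpoints of an annealed epsilon-greedy schedule. A minimal sketch of a linear anneal between them (the schedule length is an assumption, not a value from this commit):
```
def exploration_at(step, total_steps=1000000, init_eps=1.0, end_eps=0.1):
    """Linearly anneal epsilon from init_eps to end_eps, then hold."""
    frac = min(float(step) / total_steps, 1.0)
    return init_eps + frac * (end_eps - init_eps)
```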
@@ -63,10 +57,10 @@ def get_player(viz=False, train=False):
live_lost_as_eoe=train)
global NUM_ACTIONS
NUM_ACTIONS = pl.get_num_actions()
if not train:
pl = HistoryFramePlayer(pl, FRAME_HISTORY)
pl = PreventStuckPlayer(pl, 30, 1)
pl = LimitLengthPlayer(pl, 20000)
return pl
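HistoryFramePlayer stacks the last FRAME_HISTORY observations into a single state so the network can infer motion. A simplified sketch of the idea (not tensorpack's implementation):
```
from collections import deque
import numpy as np

class FrameStack(object):
    """Keep the last k frames and expose them as one (H, W, k) state."""
    def __init__(self, k=4):
        self.frames = deque(maxlen=k)

    def observe(self, frame):
        while len(self.frames) < self.frames.maxlen:
            self.frames.append(frame)   # pad with copies at episode start
        self.frames.append(frame)
        return np.stack(self.frames, axis=-1)
```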
class Model(ModelDesc):
@@ -81,7 +75,7 @@ class Model(ModelDesc):
def _get_DQN_prediction(self, image, is_training):
""" image: [0,255]"""
image = image / 255.0
with argscope(Conv2D, nl=tf.nn.relu, use_bias=True):
with argscope(Conv2D, nl=PReLU.f, use_bias=True):
l = Conv2D('conv0', image, out_channel=32, kernel_shape=5, stride=1)
l = MaxPooling('pool0', l, 2)
l = Conv2D('conv1', l, out_channel=32, kernel_shape=5, stride=1)
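Switching nl from tf.nn.relu to PReLU.f replaces the fixed zero negative slope with a learned one. A minimal TensorFlow sketch of PReLU (tensorpack's PReLU.f may differ in variable shape and initialization):
```
import tensorflow as tf

def prelu(x):
    """PReLU(x) = max(0, x) + alpha * min(0, x), with alpha learned."""
    alpha = tf.get_variable('alpha', shape=[],
                            initializer=tf.constant_initializer(0.001))
    return tf.maximum(x, 0.0) + alpha * tf.minimum(x, 0.0)
```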
@@ -158,7 +152,11 @@ def play_one_episode(player, func, verbose=False):
return np.mean(player.play_one_episode(f))
def play_model(model_path):
player = get_player(0.013)
import uuid
dirname = 'record' + str(uuid.uuid1())[:6]
print(dirname)
os.mkdir(dirname)
player = get_player(viz=dirname)
cfg = PredictConfig(
model=Model(),
input_data_mapping=[0],
@@ -168,8 +166,9 @@ def play_model(model_path):
while True:
score = play_one_episode(player, predfunc)
print("Total:", score)
break
def eval_with_funcs(predict_funcs):
def eval_with_funcs(predict_funcs, nr_eval=EVAL_EPISODE):
class Worker(StoppableThread):
def __init__(self, func, queue):
super(Worker, self).__init__()
@@ -181,7 +180,7 @@ def eval_with_funcs(predict_funcs):
score = play_one_episode(player, self.func)
self.queue_put_stoppable(self.q, score)
q = queue.Queue(maxsize=3)
q = queue.Queue(maxsize=2)
threads = [Worker(f, q) for f in predict_funcs]
for k in threads:
@@ -189,10 +188,11 @@ def eval_with_funcs(predict_funcs):
time.sleep(0.1) # avoid simulator bugs
stat = StatCounter()
try:
for _ in tqdm(range(EVAL_EPISODE)):
for _ in tqdm(range(nr_eval)):
r = q.get()
stat.feed(r)
finally:
logger.info("Waiting for all the workers to finish the last run...")
for k in threads: k.stop()
for k in threads: k.join()
return (stat.average, stat.max)
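The pattern above in miniature: each worker thread plays episodes and feeds scores into a small bounded queue, while the main thread consumes exactly nr_eval of them. A stripped-down sketch without the StoppableThread shutdown handling:
```
import threading
from six.moves import queue

def parallel_eval(play_fns, nr_eval):
    q = queue.Queue(maxsize=2)   # small buffer, as above
    def worker(play):
        while True:
            q.put(play())        # one score per finished episode
    for fn in play_fns:
        t = threading.Thread(target=worker, args=(fn,))
        t.daemon = True          # allow exit with episodes still in flight
        t.start()
    return [q.get() for _ in range(nr_eval)]
```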
@@ -214,9 +214,14 @@ class Evaluator(Callback):
NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
self.pred_funcs = [self.trainer.get_predict_func(
['state'], ['fct/output'])] * NR_PROC
self.eval_episode = EVAL_EPISODE
def _trigger_epoch(self):
mean, max = eval_with_funcs(self.pred_funcs)
t = time.time()
mean, max = eval_with_funcs(self.pred_funcs, nr_eval=self.eval_episode)
t = time.time() - t
if t > 8 * 60: # eval takes too long
self.eval_episode = int(self.eval_episode * 0.89)
self.trainer.write_scalar_summary('mean_score', mean)
self.trainer.write_scalar_summary('max_score', max)
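The new Evaluator logic above shrinks the evaluation workload geometrically whenever a full evaluation pass overruns its time budget. The same rule as a standalone function (constants copied from the code above):
```
def next_eval_episodes(current, elapsed_sec, budget_sec=8 * 60, decay=0.89):
    """Cut the episode count by ~11% when evaluation takes too long."""
    return int(current * decay) if elapsed_sec > budget_sec else current
```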
@@ -240,7 +245,7 @@ def get_config():
reward_clip=(-1, 1),
history_len=FRAME_HISTORY)
lr = tf.Variable(0.00025, trainable=False, name='learning_rate')
lr = tf.Variable(0.0004, trainable=False, name='learning_rate')
tf.scalar_summary('learning_rate', lr)
return TrainConfig(
......
Implement DQN in:
**Human-level Control Through Deep Reinforcement Learning**
and Double-DQN in:
**Deep Reinforcement Learning with Double Q-learning**
To run:
```
./DQN.py --rom breakout.rom --gpu 0
```
A demo trained with Double-DQN is available on [YouTube](https://youtu.be/o21mddZtE5Y).
@@ -4,10 +4,10 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import numpy as np
import time
import os
import time, os
import cv2
from collections import deque
import six
from six.moves import range
from ..utils import get_rng, logger, memoized
from ..utils.stat import StatCounter
@@ -37,7 +37,10 @@ class AtariPlayer(RLEnvironment):
:param frame_skip: skip every k frames and repeat the action
:param image_shape: (w, h)
:param height_range: (h1, h2) to cut
:param viz: the delay. visualize the game while running. 0 to disable
:param viz: visualization mode.
Set to 0 to disable.
Set to a positive number to show frames with that delay (in seconds) between them.
Set to a string to use it as a directory in which to store frames.
:param nullop_start: start with random number of null ops
:param live_lost_as_eoe: consider a lost life as the end of an episode. Useful for training.
"""
@@ -57,18 +60,24 @@ class AtariPlayer(RLEnvironment):
self.ale.setBool('color_averaging', False)
# manual.pdf suggests otherwise. may need to check
self.ale.setFloat('repeat_action_probability', 0.0)
self.ale.loadROM(rom_file)
self.width, self.height = self.ale.getScreenDims()
self.actions = self.ale.getMinimalActionSet()
# viz setup
if isinstance(viz, six.string_types):
assert os.path.isdir(viz), viz
self.ale.setString('record_screen_dir', viz)
viz = 0
if isinstance(viz, int):
viz = float(viz)
self.viz = viz
self.romname = os.path.basename(rom_file)
if self.viz and isinstance(self.viz, float):
self.windowname = os.path.basename(rom_file)
cv2.startWindowThread()
cv2.namedWindow(self.romname)
cv2.namedWindow(self.windowname)
self.ale.loadROM(rom_file)
self.width, self.height = self.ale.getScreenDims()
self.actions = self.ale.getMinimalActionSet()
self.live_lost_as_eoe = live_lost_as_eoe
self.frame_skip = frame_skip
@@ -95,7 +104,7 @@ class AtariPlayer(RLEnvironment):
ret = np.maximum(ret, self.last_raw_screen)
if self.viz:
if isinstance(self.viz, float):
cv2.imshow(self.romname, ret)
cv2.imshow(self.windowname, ret)
time.sleep(self.viz)
ret = ret[self.height_range[0]:self.height_range[1],:]
# weights 0.299, 0.587, 0.114; same as rgb2y in torch/image
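For reference, the luma conversion named in this comment, written out in NumPy (a sketch of the weights, not the exact code path used here):
```
import numpy as np

def rgb2y(img):
    """Grayscale an (H, W, 3) RGB image with the same weights as torch/image's rgb2y."""
    return img.astype('float32').dot(np.array([0.299, 0.587, 0.114]))
```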
......