Commit 49a21a29 authored by Yuxin Wu

DQN release ready

parent 33353f33
@@ -28,21 +28,15 @@ from tensorpack.callbacks import *
from tensorpack.RL import *
-"""
-Implement DQN in:
-Human-level Control Through Deep Reinforcement Learning
-for atari games. Use the variants in:
-Deep Reinforcement Learning with Double Q-learning.
-"""
BATCH_SIZE = 32
IMAGE_SIZE = (84, 84)
FRAME_HISTORY = 4
-ACTION_REPEAT = 3
+ACTION_REPEAT = 4
HEIGHT_RANGE = (36, 204)    # for breakout
-#HEIGHT_RANGE = (28, -8)   # for pong
CHANNEL = FRAME_HISTORY
IMAGE_SHAPE3 = IMAGE_SIZE + (CHANNEL,)
+#HEIGHT_RANGE = (28, -8)   # for pong
GAMMA = 0.99
INIT_EXPLORATION = 1
@@ -52,7 +46,7 @@ END_EXPLORATION = 0.1
MEMORY_SIZE = 1e6
INIT_MEMORY_SIZE = 50000
STEP_PER_EPOCH = 10000
-EVAL_EPISODE = 100
+EVAL_EPISODE = 50
NUM_ACTIONS = None
ROM_FILE = None
@@ -63,10 +57,10 @@ def get_player(viz=False, train=False):
live_lost_as_eoe=train)
global NUM_ACTIONS
NUM_ACTIONS = pl.get_num_actions()
if not train:
pl = HistoryFramePlayer(pl, FRAME_HISTORY)
pl = PreventStuckPlayer(pl, 30, 1)
+pl = LimitLengthPlayer(pl, 20000)
return pl
class Model(ModelDesc):
@@ -81,7 +75,7 @@ class Model(ModelDesc):
def _get_DQN_prediction(self, image, is_training):
""" image: [0,255]"""
image = image / 255.0
-with argscope(Conv2D, nl=tf.nn.relu, use_bias=True):
+with argscope(Conv2D, nl=PReLU.f, use_bias=True):
l = Conv2D('conv0', image, out_channel=32, kernel_shape=5, stride=1)
l = MaxPooling('pool0', l, 2)
l = Conv2D('conv1', l, out_channel=32, kernel_shape=5, stride=1)
@@ -158,7 +152,11 @@ def play_one_episode(player, func, verbose=False):
return np.mean(player.play_one_episode(f))
def play_model(model_path):
-player = get_player(0.013)
+import uuid
+dirname = 'record' + str(uuid.uuid1())[:6]
+print dirname
+os.mkdir(dirname)
+player = get_player(viz=dirname)
cfg = PredictConfig(
model=Model(),
input_data_mapping=[0],
@@ -168,8 +166,9 @@ def play_model(model_path):
while True:
score = play_one_episode(player, predfunc)
print("Total:", score)
+break
-def eval_with_funcs(predict_funcs):
+def eval_with_funcs(predict_funcs, nr_eval=EVAL_EPISODE):
class Worker(StoppableThread):
def __init__(self, func, queue):
super(Worker, self).__init__()
@@ -181,7 +180,7 @@ def eval_with_funcs(predict_funcs):
score = play_one_episode(player, self.func)
self.queue_put_stoppable(self.q, score)
-q = queue.Queue(maxsize=3)
+q = queue.Queue(maxsize=2)
threads = [Worker(f, q) for f in predict_funcs]
for k in threads:
@@ -189,10 +188,11 @@ def eval_with_funcs(predict_funcs):
time.sleep(0.1)  # avoid simulator bugs
stat = StatCounter()
try:
-for _ in tqdm(range(EVAL_EPISODE)):
+for _ in tqdm(range(nr_eval)):
r = q.get()
stat.feed(r)
finally:
+logger.info("Waiting for all the workers to finish the last run...")
for k in threads: k.stop()
for k in threads: k.join()
return (stat.average, stat.max)
@@ -214,9 +214,14 @@ class Evaluator(Callback):
NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
self.pred_funcs = [self.trainer.get_predict_func(
['state'], ['fct/output'])] * NR_PROC
+self.eval_episode = EVAL_EPISODE
def _trigger_epoch(self):
-mean, max = eval_with_funcs(self.pred_funcs)
+t = time.time()
+mean, max = eval_with_funcs(self.pred_funcs, nr_eval=self.eval_episode)
+t = time.time() - t
+if t > 8 * 60:  # eval takes too long
+self.eval_episode = int(self.eval_episode * 0.89)
self.trainer.write_scalar_summary('mean_score', mean)
self.trainer.write_scalar_summary('max_score', max)
@@ -240,7 +245,7 @@ def get_config():
reward_clip=(-1, 1),
history_len=FRAME_HISTORY)
-lr = tf.Variable(0.00025, trainable=False, name='learning_rate')
+lr = tf.Variable(0.0004, trainable=False, name='learning_rate')
tf.scalar_summary('learning_rate', lr)
return TrainConfig(
...
Implements DQN as described in:
**Human-level Control Through Deep Reinforcement Learning**
and Double-DQN as described in:
**Deep Reinforcement Learning with Double Q-learning**
To run:
```
./DQN.py --rom breakout.rom --gpu 0
```
A demo of a model trained with Double-DQN is available on [youtube](https://youtu.be/o21mddZtE5Y).
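For context, the Double-DQN variant named above decouples action selection from action evaluation: the online network picks the greedy next action, while the target network estimates its value, which reduces the over-estimation of Q-values that plain DQN suffers from. The following NumPy sketch of that target computation is illustrative only; names such as `double_dqn_target`, `q_online_next`, and `q_target_next` are not from this repository.

```
import numpy as np

def double_dqn_target(reward, is_over, q_online_next, q_target_next, gamma=0.99):
    """Double-DQN target: r + gamma * Q_target(s', argmax_a Q_online(s', a)).

    reward, is_over: shape (batch,); q_*_next: shape (batch, num_actions).
    """
    # The online network selects the greedy action for the next state ...
    best_action = np.argmax(q_online_next, axis=1)
    # ... and the target network provides the value estimate of that action.
    next_value = q_target_next[np.arange(len(best_action)), best_action]
    # Terminal transitions (is_over == 1) receive no bootstrapped value.
    return reward + gamma * (1.0 - is_over) * next_value
```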
@@ -4,10 +4,10 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import numpy as np
-import time
-import os
+import time, os
import cv2
from collections import deque
+import six
from six.moves import range
from ..utils import get_rng, logger, memoized
from ..utils.stat import StatCounter
@@ -37,7 +37,10 @@ class AtariPlayer(RLEnvironment):
:param frame_skip: skip every k frames and repeat the action
:param image_shape: (w, h)
:param height_range: (h1, h2) to cut
-:param viz: the delay. visualize the game while running. 0 to disable
+:param viz: visualization to be done.
+Set to 0 to disable.
+Set to a positive number to be the delay between frames to show.
+Set to a string to be a directory to store frames.
:param nullop_start: start with random number of null ops
:param live_losts_as_eoe: consider lost of lives as end of episode. useful for training.
@@ -57,18 +60,24 @@ class AtariPlayer(RLEnvironment):
self.ale.setBool('color_averaging', False)
# manual.pdf suggests otherwise. may need to check
self.ale.setFloat('repeat_action_probability', 0.0)
-self.ale.loadROM(rom_file)
-self.width, self.height = self.ale.getScreenDims()
-self.actions = self.ale.getMinimalActionSet()
+# viz setup
+if isinstance(viz, six.string_types):
+assert os.path.isdir(viz), viz
+self.ale.setString('record_screen_dir', viz)
+viz = 0
if isinstance(viz, int):
viz = float(viz)
self.viz = viz
-self.romname = os.path.basename(rom_file)
if self.viz and isinstance(self.viz, float):
+self.windowname = os.path.basename(rom_file)
cv2.startWindowThread()
-cv2.namedWindow(self.romname)
+cv2.namedWindow(self.windowname)
+self.ale.loadROM(rom_file)
+self.width, self.height = self.ale.getScreenDims()
+self.actions = self.ale.getMinimalActionSet()
self.live_lost_as_eoe = live_lost_as_eoe
self.frame_skip = frame_skip
@@ -95,7 +104,7 @@ class AtariPlayer(RLEnvironment):
ret = np.maximum(ret, self.last_raw_screen)
if self.viz:
if isinstance(self.viz, float):
-cv2.imshow(self.romname, ret)
+cv2.imshow(self.windowname, ret)
time.sleep(self.viz)
ret = ret[self.height_range[0]:self.height_range[1],:]
# 0.299,0.587.0.114. same as rgb2y in torch/image
...
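The new `viz` argument documented above accepts three kinds of values. Below is a rough usage sketch; the import path and the assumption that the ROM path is the first positional argument are inferred from context rather than taken verbatim from this diff.

```
import os
from tensorpack.RL import AtariPlayer   # assumed import path

AtariPlayer('breakout.rom', viz=0)         # no visualization
AtariPlayer('breakout.rom', viz=0.01)      # show each frame, sleeping 0.01s between frames

os.mkdir('record')                         # the directory must already exist (os.path.isdir is asserted)
AtariPlayer('breakout.rom', viz='record')  # let ALE dump every frame into 'record/'
```

This mirrors what the updated `play_model` does: it creates a fresh `record<uuid>` directory and passes it as `viz`, so that evaluation episodes are recorded to disk.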