Commit 80a110e2 authored by Yuxin Wu

[DQN] Let state have channels.

parent f227f45f
@@ -278,9 +278,8 @@ if __name__ == '__main__':
     args = parser.parse_args()
     ENV_NAME = args.env
-    logger.info("Environment Name: {}".format(ENV_NAME))
     NUM_ACTIONS = get_player().action_space.n
-    logger.info("Number of actions: {}".format(NUM_ACTIONS))
+    logger.info("Environment: {}, number of actions: {}".format(ENV_NAME, NUM_ACTIONS))
 
     if args.gpu:
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
...
@@ -6,6 +6,7 @@
 import os
 import argparse
 import cv2
+import numpy as np
 import tensorflow as tf
@@ -40,7 +41,7 @@ def get_player(viz=False, train=False):
     env = AtariPlayer(ROM_FILE, frame_skip=ACTION_REPEAT, viz=viz,
                       live_lost_as_eoe=train, max_num_frames=60000)
     env = FireResetEnv(env)
-    env = MapState(env, lambda im: cv2.resize(im, IMAGE_SIZE))
+    env = MapState(env, lambda im: cv2.resize(im, IMAGE_SIZE)[:, :, np.newaxis])
     if not train:
         # in training, history is taken care of in expreplay buffer
         env = FrameStack(env, FRAME_HISTORY)
@@ -49,10 +50,10 @@ def get_player(viz=False, train=False):
 class Model(DQNModel):
     def __init__(self):
-        super(Model, self).__init__(IMAGE_SIZE, FRAME_HISTORY, METHOD, NUM_ACTIONS, GAMMA)
+        super(Model, self).__init__(IMAGE_SIZE, 1, FRAME_HISTORY, METHOD, NUM_ACTIONS, GAMMA)
 
     def _get_DQN_prediction(self, image):
-        """ image: [0,255]"""
+        """ image: [N, H, W, C * history] in [0,255]"""
         image = image / 255.0
         with argscope(Conv2D, activation=lambda x: PReLU('prelu', x), use_bias=True):
             l = (LinearWrap(image)
@@ -86,7 +87,7 @@ def get_config():
     expreplay = ExpReplay(
         predictor_io_names=(['state'], ['Qvalue']),
         player=get_player(train=True),
-        state_shape=IMAGE_SIZE,
+        state_shape=IMAGE_SIZE + (1,),
         batch_size=BATCH_SIZE,
         memory_size=MEMORY_SIZE,
         init_memory_size=INIT_MEMORY_SIZE,
...
@@ -14,18 +14,24 @@ from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
 class Model(ModelDesc):
     learning_rate = 1e-3
 
-    def __init__(self, image_shape, channel, method, num_actions, gamma):
-        self.image_shape = image_shape
+    def __init__(self, image_shape, channel, history, method, num_actions, gamma):
         self.channel = channel
+        self._shape2d = image_shape
+        self._shape3d = image_shape + (channel, )
+        self._shape4d_for_prediction = (-1, ) + image_shape + (channel * history, )
+        self._channel = channel
+        self.history = history
         self.method = method
         self.num_actions = num_actions
         self.gamma = gamma
 
     def inputs(self):
-        # Use a combined state for efficiency.
-        # The first h channels are the current state, and the last h channels are the next state.
+        # When we use h history frames, the current state and the next state will have (h-1) overlapping frames.
+        # Therefore we use a combined state for efficiency:
+        # The first h are the current state, and the last h are the next state.
         return [tf.placeholder(tf.uint8,
-                               (None,) + self.image_shape + (self.channel + 1,),
+                               (None,) + self._shape2d +
+                               (self._channel * (self.history + 1),),
                                'comb_state'),
                 tf.placeholder(tf.int64, (None,), 'action'),
                 tf.placeholder(tf.float32, (None,), 'reward'),
@@ -35,20 +41,23 @@ class Model(ModelDesc):
     def _get_DQN_prediction(self, image):
         pass
 
+    # decorate the function
     @auto_reuse_variable_scope
     def get_DQN_prediction(self, image):
         return self._get_DQN_prediction(image)
 
     def build_graph(self, comb_state, action, reward, isOver):
         comb_state = tf.cast(comb_state, tf.float32)
-        state = tf.slice(comb_state, [0, 0, 0, 0], [-1, -1, -1, self.channel], name='state')
+        comb_state = tf.reshape(comb_state, [-1] + list(self._shape3d) + [self.history + 1])
+        state = tf.slice(comb_state, [0, 0, 0, 0, 0], [-1, -1, -1, -1, self.history])
+        state = tf.reshape(state, self._shape4d_for_prediction, name='state')
 
         self.predict_value = self.get_DQN_prediction(state)
         if not get_current_tower_context().is_training:
             return
 
         reward = tf.clip_by_value(reward, -1, 1)
-        next_state = tf.slice(comb_state, [0, 0, 0, 1], [-1, -1, -1, self.channel], name='next_state')
+        next_state = tf.slice(comb_state, [0, 0, 0, 0, 1], [-1, -1, -1, -1, self.history], name='next_state')
+        next_state = tf.reshape(next_state, self._shape4d_for_prediction)
         action_onehot = tf.one_hot(action, self.num_actions, 1.0, 0.0)
 
         pred_action_value = tf.reduce_sum(self.predict_value * action_onehot, 1)  # N,
...
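For readers unfamiliar with the combined-state trick, here is a minimal numpy sketch (not part of the commit) of what the reshape/slice in `build_graph` does, assuming grayscale frames (C=1) and a history of 4:

```python
import numpy as np

# Hypothetical sizes for illustration only.
N, H, W, C, HIST = 2, 84, 84, 1, 4
comb_state = np.random.randint(0, 256, size=(N, H, W, C * (HIST + 1)), dtype=np.uint8)

# Split the flat channel axis into (channel, time), as tf.reshape does above.
x = comb_state.reshape(N, H, W, C, HIST + 1)

# Frames [0, HIST) form the current state, frames [1, HIST] the next state;
# they share HIST - 1 frames, which is why one combined tensor suffices.
state = x[..., :HIST].reshape(N, H, W, C * HIST)
next_state = x[..., 1:].reshape(N, H, W, C * HIST)
assert state.shape == next_state.shape == (N, H, W, C * HIST)
```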
@@ -20,10 +20,9 @@ Claimed performance in the paper can be reproduced, on several games I've tested
 ![DQN](curve-breakout.png)
 
-On one TitanX, Double-DQN took 1 day of training to reach a score of 400 on breakout.
-Batch-A3C implementation only took <2 hours.
+On one (Maxwell) TitanX, Double-DQN took ~18 hours of training to reach a score of 400 on breakout.
 
-Double-DQN with nature paper setting runs at 60 batches (3840 trained frames, 240 seen frames, 960 game frames) per second on (Maxwell) TitanX.
+Double-DQN with nature paper setting runs at 60 batches (3840 trained frames, 240 seen frames, 960 game frames) per second on TitanX.
 
 ## How to use
...
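As a rough sanity check of the throughput numbers in that README paragraph (a sketch based on assumed defaults of BATCH_SIZE=64, UPDATE_FREQ=4 and a frame skip of 4, none of which are shown in this diff):

```python
# Assumed defaults, not taken from this commit.
BATCHES_PER_SEC = 60
BATCH_SIZE = 64      # frames trained per batch
UPDATE_FREQ = 4      # new transitions collected per training batch
FRAME_SKIP = 4       # emulator frames per observed frame

trained_frames = BATCHES_PER_SEC * BATCH_SIZE   # 3840 trained frames / s
seen_frames = BATCHES_PER_SEC * UPDATE_FREQ     # 240 seen frames / s
game_frames = seen_frames * FRAME_SKIP          # 960 game frames / s
print(trained_frames, seen_frames, game_frames)
```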
@@ -25,6 +25,9 @@ class ReplayMemory(object):
     def __init__(self, max_size, state_shape, history_len):
         self.max_size = int(max_size)
         self.state_shape = state_shape
+        self._state_transpose = list(range(1, len(state_shape) + 1)) + [0]
+        self._channel = state_shape[2] if len(state_shape) == 3 else 1
+        self._shape3d = (state_shape[0], state_shape[1], self._channel * (history_len + 1))
         self.history_len = int(history_len)
 
         self.state = np.zeros((self.max_size,) + state_shape, dtype='uint8')
@@ -62,7 +65,7 @@ class ReplayMemory(object):
     def sample(self, idx):
         """ return a tuple of (s,r,a,o),
-            where s is of shape STATE_SIZE + (hist_len+1,)"""
+            where s is of shape [H, W, channel * (hist_len+1)]"""
         idx = (self._curr_pos + idx) % self._curr_size
         k = self.history_len + 1
         if idx + k <= self._curr_size:
@@ -86,7 +89,9 @@ class ReplayMemory(object):
                 state = copy.deepcopy(state)
                 state[:k + 1].fill(0)
                 break
-        state = state.transpose(1, 2, 0)
+        # move the first dim to the last
+        state = state.transpose(*self._state_transpose)
+        state = state.reshape(self._shape3d)
         return (state, reward[-2], action[-2], isOver[-2])
 
     def _slice(self, arr, start, end):
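A small numpy sketch (illustrative only, not part of the commit) of the transpose + reshape above, assuming an (84, 84, 1) state_shape and history_len=4:

```python
import numpy as np

H, W, C, HIST = 84, 84, 1, 4           # assumed sizes
state_shape = (H, W, C)

# Stacked frames as stored in the replay buffer: (hist_len + 1, H, W, C).
frames = np.zeros((HIST + 1,) + state_shape, dtype=np.uint8)

transpose_order = list(range(1, len(state_shape) + 1)) + [0]   # (1, 2, 3, 0)
s = frames.transpose(*transpose_order)                         # (H, W, C, HIST + 1)
s = s.reshape(H, W, C * (HIST + 1))                            # the new _shape3d layout
assert s.shape == (H, W, C * (HIST + 1))
```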
@@ -130,11 +135,13 @@ class ExpReplay(DataFlow, Callback):
             predictor_io_names (tuple of list of str): input/output names to
                 predict Q value from state.
             player (RLEnvironment): the player.
+            state_shape (tuple): h, w, c
             history_len (int): length of history frames to concat. Zero-filled
                 initial frames.
             update_frequency (int): number of new transitions to add to memory
                 after sampling a batch of transitions for training.
         """
+        assert len(state_shape) == 3, state_shape
         init_memory_size = int(init_memory_size)
 
         for k, v in locals().items():
@@ -195,10 +202,10 @@ class ExpReplay(DataFlow, Callback):
             # build a history state
             history = self.mem.recent_state()
             history.append(old_s)
-            history = np.stack(history, axis=2)
+            history = np.concatenate(history, axis=-1)
 
             # assume batched network
-            q_values = self.predictor(history[None, :, :, :])[0][0]  # this is the bottleneck
+            q_values = self.predictor(np.expand_dims(history, 0))[0][0]  # this is the bottleneck
             act = np.argmax(q_values)
         self._current_ob, reward, isOver, info = self.player.step(act)
         self._current_game_score.feed(reward)
...
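And a sketch (again illustrative only) of the history concatenation used before the predictor call, assuming 4 history frames of shape (84, 84, 1):

```python
import numpy as np

H, W, C, HIST = 84, 84, 1, 4           # assumed sizes
history = [np.zeros((H, W, C), dtype=np.uint8) for _ in range(HIST)]

stacked = np.concatenate(history, axis=-1)   # (H, W, C * HIST): all frames stacked on the last axis
batched = np.expand_dims(stacked, 0)         # (1, H, W, C * HIST) for the batched predictor
assert batched.shape == (1, H, W, C * HIST)
```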