Commit ada058f3 authored by Yuxin Wu

update docs in DQN, remove a 4-frame hard-coded assumption (#268)

parent c3effc3c
@@ -32,7 +32,6 @@ IMAGE_SIZE = (84, 84)
 FRAME_HISTORY = 4
 ACTION_REPEAT = 4
-CHANNEL = FRAME_HISTORY
 GAMMA = 0.99
 INIT_EXPLORATION = 1
@@ -54,8 +53,11 @@ def get_player(viz=False, train=False):
     pl = AtariPlayer(ROM_FILE, frame_skip=ACTION_REPEAT,
                      image_shape=IMAGE_SIZE[::-1], viz=viz, live_lost_as_eoe=train)
     if not train:
+        # create a new axis to stack history on
         pl = MapPlayerState(pl, lambda im: im[:, :, np.newaxis])
+        # in training, history is taken care of in expreplay buffer
         pl = HistoryFramePlayer(pl, FRAME_HISTORY)
         pl = PreventStuckPlayer(pl, 30, 1)
     pl = LimitLengthPlayer(pl, 30000)
     return pl
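The evaluation-time wrappers above reproduce what the replay buffer already does during training: each grayscale frame gains a channel axis, and the last FRAME_HISTORY frames are concatenated along it. A minimal NumPy sketch of that stacking (the frames deque and stack_history helper are illustrative, not tensorpack API, and zero-padding an incomplete history is an assumption about HistoryFramePlayer's exact behavior):

import numpy as np
from collections import deque

FRAME_HISTORY = 4
frames = deque(maxlen=FRAME_HISTORY)  # hypothetical history buffer

def stack_history(im):
    # like MapPlayerState: give the (84, 84) frame a channel axis
    frames.append(im[:, :, np.newaxis])
    hist = list(frames)
    # assumed behavior: pad with zero frames until the history is full
    while len(hist) < FRAME_HISTORY:
        hist.insert(0, np.zeros_like(hist[0]))
    # like HistoryFramePlayer: stack along the channel axis
    return np.concatenate(hist, axis=-1)

state = stack_history(np.zeros((84, 84), dtype=np.uint8))
assert state.shape == (84, 84, FRAME_HISTORY)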
@@ -63,7 +65,7 @@ def get_player(viz=False, train=False):
 class Model(DQNModel):
     def __init__(self):
-        super(Model, self).__init__(IMAGE_SIZE, CHANNEL, METHOD, NUM_ACTIONS, GAMMA)
+        super(Model, self).__init__(IMAGE_SIZE, FRAME_HISTORY, METHOD, NUM_ACTIONS, GAMMA)

     def _get_DQN_prediction(self, image):
         """ image: [0,255]"""
......
@@ -21,8 +21,8 @@ class Model(ModelDesc):
         self.gamma = gamma

     def _get_inputs(self):
-        # use a combined state, where the first channels are the current state,
-        # and the last 4 channels are the next state
+        # Use a combined state for efficiency.
+        # The first h channels are the current state, and the last h channels are the next state.
         return [InputDesc(tf.uint8,
                           (None,) + self.image_shape + (self.channel + 1,),
                           'comb_state'),
@@ -37,13 +37,13 @@ class Model(ModelDesc):
     def _build_graph(self, inputs):
         comb_state, action, reward, isOver = inputs
         comb_state = tf.cast(comb_state, tf.float32)
-        state = tf.slice(comb_state, [0, 0, 0, 0], [-1, -1, -1, 4], name='state')
+        state = tf.slice(comb_state, [0, 0, 0, 0], [-1, -1, -1, self.channel], name='state')
         self.predict_value = self._get_DQN_prediction(state)
         if not get_current_tower_context().is_training:
             return

         reward = tf.clip_by_value(reward, -1, 1)
-        next_state = tf.slice(comb_state, [0, 0, 0, 1], [-1, -1, -1, 4], name='next_state')
+        next_state = tf.slice(comb_state, [0, 0, 0, 1], [-1, -1, -1, self.channel], name='next_state')
         action_onehot = tf.one_hot(action, self.num_actions, 1.0, 0.0)
         pred_action_value = tf.reduce_sum(self.predict_value * action_onehot, 1)  # N,
......
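The two tf.slice changes above are the heart of this commit: a transition is stored as h+1 consecutive frames, and the current and next states are read out as two overlapping h-channel windows, so the buffer holds h+1 frames per transition instead of 2h. A NumPy sketch of the same slicing (shapes taken from the comb_state InputDesc above; NumPy indexing stands in for tf.slice):

import numpy as np

h = 4  # FRAME_HISTORY; any value works now that 4 is no longer hard-coded
comb_state = np.zeros((2, 84, 84, h + 1), dtype=np.uint8)  # fake batch of 2

state = comb_state[..., 0:h]           # tf.slice(..., [0,0,0,0], [-1,-1,-1,h])
next_state = comb_state[..., 1:h + 1]  # tf.slice(..., [0,0,0,1], [-1,-1,-1,h])

assert state.shape == next_state.shape == (2, 84, 84, h)
# the two windows share h-1 channels, which is what saves the storage:
assert np.array_equal(state[..., 1:], next_state[..., :-1])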
@@ -113,7 +113,7 @@ class ExpReplay(DataFlow, Callback):
     This implementation provides the interface as a :class:`DataFlow`.
     This DataFlow is __not__ fork-safe (thus doesn't support multiprocess prefetching).

-    This implementation only works with Q-learning. It assumes that state is
+    This implementation assumes that state is
     batch-able, and the network takes batched inputs.
     """
@@ -171,6 +171,18 @@ class ExpReplay(DataFlow, Callback):
                 pbar.update()
         self._init_memory_flag.set()

+    # quickly fill the memory for debug
+    def _fake_init_memory(self):
+        from copy import deepcopy
+        with get_tqdm(total=self.init_memory_size) as pbar:
+            while len(self.mem) < 5:
+                self._populate_exp()
+                pbar.update()
+            while len(self.mem) < self.init_memory_size:
+                self.mem.append(deepcopy(self.mem._hist[0]))
+                pbar.update()
+        self._init_memory_flag.set()
+
     def _populate_exp(self):
         """ populate a transition by epsilon-greedy"""
         old_s = self.player.current_state()
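Note on _fake_init_memory above: it collects only a handful of real transitions and then pads the buffer with deep copies of an existing entry, so the data pipeline can be exercised without waiting for the full init_memory_size warm-up. It is a debugging shortcut, not something to train with.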
@@ -188,7 +200,7 @@ class ExpReplay(DataFlow, Callback):
         reward, isOver = self.player.action(act)
         self.mem.append(Experience(old_s, act, reward, isOver))

-    def debug_sample(self, sample):
+    def _debug_sample(self, sample):
         import cv2

         def view_state(comb_state):
......
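_populate_exp is shown truncated above; per its docstring it advances the player by one epsilon-greedy step and appends the transition. A hedged sketch of that step (the predictor callable and exploration value are assumptions, and the real method also assembles the recent frame history from memory before the forward pass):

import numpy as np

def populate_exp_sketch(player, predictor, num_actions, exploration, mem):
    # epsilon-greedy: explore with probability `exploration`
    old_s = player.current_state()
    if np.random.rand() <= exploration:
        act = np.random.randint(num_actions)
    else:
        q_values = predictor(old_s[None])[0]  # batched forward pass
        act = int(np.argmax(q_values))
    reward, is_over = player.action(act)
    mem.append((old_s, act, reward, is_over))  # Experience(...) in the real code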
@@ -12,7 +12,7 @@ __all__ = ['ColorSpace', 'Grayscale', 'ToUint8', 'ToFloat32']
 class ColorSpace(ImageAugmentor):
     """ Convert into another colorspace. """

-    def __init__(self, mode=cv2.COLOR_BGR2GRAY, keepdims=True):
+    def __init__(self, mode, keepdims=True):
         """
         Args:
             mode: opencv colorspace conversion code (e.g., `cv2.COLOR_BGR2HSV`)
......
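With the default gone, callers of ColorSpace now pass the conversion code explicitly. Functionally the augmentor boils down to cv2.cvtColor plus an optional re-added channel axis; a rough standalone sketch of that behavior (not the ImageAugmentor machinery itself):

import cv2
import numpy as np

def color_space_sketch(img, mode, keepdims=True):
    out = cv2.cvtColor(img, mode)
    if keepdims and out.ndim == 2:
        out = out[:, :, np.newaxis]  # e.g. BGR2GRAY drops the channel axis
    return out

gray = color_space_sketch(np.zeros((84, 84, 3), dtype=np.uint8),
                          cv2.COLOR_BGR2GRAY)  # mode is now explicit
assert gray.shape == (84, 84, 1)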