Commit 0b561b3b authored by Yuxin Wu

Make DQN support states with more dimensions

parent 87fad54b
...@@ -34,8 +34,7 @@ else: ...@@ -34,8 +34,7 @@ else:
IMAGE_SIZE = (84, 84) IMAGE_SIZE = (84, 84)
FRAME_HISTORY = 4 FRAME_HISTORY = 4
GAMMA = 0.99 GAMMA = 0.99
CHANNEL = FRAME_HISTORY * 3 STATE_SHAPE = IMAGE_SIZE + (3, )
IMAGE_SHAPE3 = IMAGE_SIZE + (CHANNEL,)
LOCAL_TIME_MAX = 5 LOCAL_TIME_MAX = 5
STEPS_PER_EPOCH = 6000 STEPS_PER_EPOCH = 6000
@@ -70,13 +69,17 @@ class MySimulatorWorker(SimulatorProcess):
 class Model(ModelDesc):
     def inputs(self):
         assert NUM_ACTIONS is not None
-        return [tf.placeholder(tf.uint8, (None,) + IMAGE_SHAPE3, 'state'),
+        return [tf.placeholder(tf.uint8, (None,) + STATE_SHAPE + (FRAME_HISTORY, ), 'state'),
                 tf.placeholder(tf.int64, (None,), 'action'),
                 tf.placeholder(tf.float32, (None,), 'futurereward'),
                 tf.placeholder(tf.float32, (None,), 'action_prob'),
                 ]
 
-    def _get_NN_prediction(self, image):
+    def _get_NN_prediction(self, state):
+        assert state.shape.rank == 5  # Batch, H, W, Channel, History
+        state = tf.transpose(state, [0, 1, 2, 4, 3])  # swap channel & history, to be compatible with old models
+        image = tf.reshape(state, [-1] + list(STATE_SHAPE[:2]) + [STATE_SHAPE[2] * FRAME_HISTORY])
         image = tf.cast(image, tf.float32) / 255.0
         with argscope(Conv2D, activation=tf.nn.relu):
             l = Conv2D('conv0', image, 32, 5)
...
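
Note (not part of the commit): a minimal NumPy sketch of why the transpose-then-reshape above stays compatible with the old (H, W, History * Channel) layout; shapes are illustrative.

import numpy as np

# Illustrative shapes: 2 history frames of 4x4 RGB images.
H, W, C, HIST = 4, 4, 3, 2
frames = [np.random.randint(0, 255, (H, W, C)) for _ in range(HIST)]

old_layout = np.concatenate(frames, axis=-1)   # (H, W, HIST * C): old FrameStack output
new_layout = np.stack(frames, axis=-1)         # (H, W, C, HIST): new FrameStack output

# Swap channel & history, then merge them -- the NumPy analogue of the tf.transpose + tf.reshape above.
merged = np.transpose(new_layout, (0, 1, 3, 2)).reshape(H, W, HIST * C)
assert np.array_equal(merged, old_layout)
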
@@ -19,7 +19,7 @@ from expreplay import ExpReplay
 BATCH_SIZE = 64
 IMAGE_SIZE = (84, 84)
-IMAGE_CHANNEL = None  # 3 in gym and 1 in our own wrapper
+STATE_SHAPE = None  # IMAGE_SIZE + (3,) in gym, and IMAGE_SIZE in ALE
 FRAME_HISTORY = 4
 ACTION_REPEAT = 4   # aka FRAME_SKIP
 UPDATE_FREQ = 4
@@ -39,8 +39,7 @@ METHOD = None
 def resize_keepdims(im, size):
-    # Opencv's resize remove the extra dimension for grayscale images.
-    # We add it back.
+    # Opencv's resize remove the extra dimension for grayscale images. We add it back.
     ret = cv2.resize(im, size)
     if im.ndim == 3 and ret.ndim == 2:
         ret = ret[:, :, np.newaxis]
@@ -65,10 +64,20 @@ def get_player(viz=False, train=False):
 class Model(DQNModel):
+    """
+    A DQN model for 2D/3D (image) observations.
+    """
     def __init__(self):
-        super(Model, self).__init__(IMAGE_SIZE, IMAGE_CHANNEL, FRAME_HISTORY, METHOD, NUM_ACTIONS, GAMMA)
+        assert len(STATE_SHAPE) in [2, 3]
+        super(Model, self).__init__(STATE_SHAPE, FRAME_HISTORY, METHOD, NUM_ACTIONS, GAMMA)
 
     def _get_DQN_prediction(self, image):
+        assert image.shape.rank in [4, 5], image.shape
+        # image: N, H, W, (C), Hist
+        if image.shape.rank == 5:
+            # merge C & Hist
+            image = tf.reshape(image, [-1] + list(STATE_SHAPE[:2]) + [STATE_SHAPE[2] * FRAME_HISTORY])
         image = image / 255.0
         with argscope(Conv2D, activation=lambda x: PReLU('prelu', x), use_bias=True):
             l = (LinearWrap(image)
@@ -102,7 +111,7 @@ def get_config():
     expreplay = ExpReplay(
         predictor_io_names=(['state'], ['Qvalue']),
         player=get_player(train=True),
-        state_shape=IMAGE_SIZE + (IMAGE_CHANNEL,),
+        state_shape=STATE_SHAPE,
         batch_size=BATCH_SIZE,
         memory_size=MEMORY_SIZE,
         init_memory_size=INIT_MEMORY_SIZE,
@@ -152,7 +161,7 @@ if __name__ == '__main__':
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
     ENV_NAME = args.env
     USE_GYM = not ENV_NAME.endswith('.bin')
-    IMAGE_CHANNEL = 3 if USE_GYM else 1
+    STATE_SHAPE = IMAGE_SIZE + (3, ) if USE_GYM else IMAGE_SIZE
     METHOD = args.algo
     # set num_actions
     NUM_ACTIONS = get_player().action_space.n
...
@@ -12,16 +12,19 @@ from tensorpack.utils import logger
 class Model(ModelDesc):
     learning_rate = 1e-3
 
-    def __init__(self, image_shape, channel, history, method, num_actions, gamma):
-        assert len(image_shape) == 2, image_shape
-        self.channel = channel
-        self._shape2d = tuple(image_shape)
-        self._shape3d = self._shape2d + (channel, )
-        self._shape4d_for_prediction = (-1, ) + self._shape2d + (history * channel, )
-        self._channel = channel
+    state_dtype = tf.uint8
+
+    def __init__(self, state_shape, history, method, num_actions, gamma):
+        """
+        Args:
+            state_shape (tuple[int]),
+            history (int):
+        """
+        self._state_shape = tuple(state_shape)
+        self._stacked_state_shape = (-1, ) + self._state_shape + (history, )
         self.history = history
         self.method = method
         self.num_actions = num_actions
@@ -31,37 +34,43 @@ class Model(ModelDesc):
         # When we use h history frames, the current state and the next state will have (h-1) overlapping frames.
         # Therefore we use a combined state for efficiency:
         # The first h are the current state, and the last h are the next state.
-        return [tf.placeholder(tf.uint8,
-                               (None,) + self._shape2d +
-                               ((self.history + 1) * self.channel,),
+        return [tf.placeholder(self.state_dtype,
+                               (None,) + self._state_shape + (self.history + 1, ),
                                'comb_state'),
                 tf.placeholder(tf.int64, (None,), 'action'),
                 tf.placeholder(tf.float32, (None,), 'reward'),
                 tf.placeholder(tf.bool, (None,), 'isOver')]
 
     @abc.abstractmethod
-    def _get_DQN_prediction(self, image):
+    def _get_DQN_prediction(self, state):
+        """
+        state: N + state_shape + history
+        """
         pass
 
     @auto_reuse_variable_scope
-    def get_DQN_prediction(self, image):
-        """ image: [N, H, W, history * C] in [0,255]"""
-        return self._get_DQN_prediction(image)
+    def get_DQN_prediction(self, state):
+        return self._get_DQN_prediction(state)
 
     def build_graph(self, comb_state, action, reward, isOver):
         comb_state = tf.cast(comb_state, tf.float32)
-        comb_state = tf.reshape(
-            comb_state, [-1] + list(self._shape2d) + [self.history + 1, self.channel])
-
-        state = tf.slice(comb_state, [0, 0, 0, 0, 0], [-1, -1, -1, self.history, -1])
-        state = tf.reshape(state, self._shape4d_for_prediction, name='state')
+        input_rank = comb_state.shape.rank
+
+        state = tf.slice(
+            comb_state,
+            [0] * input_rank,
+            [-1] * (input_rank - 1) + [self.history], name='state')
 
         self.predict_value = self.get_DQN_prediction(state)
         if not get_current_tower_context().is_training:
             return
 
         reward = tf.clip_by_value(reward, -1, 1)
-        next_state = tf.slice(comb_state, [0, 0, 0, 1, 0], [-1, -1, -1, self.history, -1], name='next_state')
-        next_state = tf.reshape(next_state, self._shape4d_for_prediction)
+        next_state = tf.slice(
+            comb_state,
+            [0] * (input_rank - 1) + [1],
+            [-1] * (input_rank - 1) + [self.history], name='next_state')
+        next_state = tf.reshape(next_state, self._stacked_state_shape)
         action_onehot = tf.one_hot(action, self.num_actions, 1.0, 0.0)
         pred_action_value = tf.reduce_sum(self.predict_value * action_onehot, 1)  # N,
...
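
Note (not part of the commit): the combined state carries history + 1 frames on its last axis; the current state is the first `history` frames and the next state is the last `history` frames. A rank-agnostic NumPy sketch of the two tf.slice calls above, with illustrative shapes:

import numpy as np

history = 4
# batch + state_shape + (history + 1,); state_shape here is (84, 84) for illustration
comb_state = np.random.rand(8, 84, 84, history + 1)

state = comb_state[..., :history]    # frames [0, history): current state
next_state = comb_state[..., 1:]     # frames [1, history + 1): next state
assert state.shape == next_state.shape == (8, 84, 84, history)
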
@@ -94,7 +94,7 @@ class AtariPlayer(gym.Env):
         self.action_space = spaces.Discrete(len(self.actions))
         self.observation_space = spaces.Box(
-            low=0, high=255, shape=(self.height, self.width, 1), dtype=np.uint8)
+            low=0, high=255, shape=(self.height, self.width), dtype=np.uint8)
         self._restart_episode()
 
     def get_action_meanings(self):
@@ -109,7 +109,7 @@ class AtariPlayer(gym.Env):
     def _current_state(self):
         """
-        :returns: a gray-scale (h, w, 1) uint8 image
+        :returns: a gray-scale (h, w) uint8 image
         """
         ret = self._grab_raw_image()
         # max-pooled over the last screen
@@ -120,7 +120,7 @@ class AtariPlayer(gym.Env):
             cv2.waitKey(int(self.viz * 1000))
         ret = ret.astype('float32')
         # 0.299,0.587.0.114. same as rgb2y in torch/image
-        ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)[:, :, np.newaxis]
+        ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)[:, :]
         return ret.astype('uint8')  # to save some memory
 
     def _restart_episode(self):
...
@@ -4,7 +4,6 @@
 import numpy as np
 from collections import deque
 import gym
-from gym import spaces
 
 _v0, _v1 = gym.__version__.split('.')[:2]
 assert int(_v0) > 0 or int(_v1) >= 10, gym.__version__
@@ -27,17 +26,13 @@ class MapState(gym.ObservationWrapper):
 class FrameStack(gym.Wrapper):
     """
-    Buffer observations and stack across channels (last axis).
-    The output observation has shape (H, W, History * Channel)
+    Buffer consecutive k observations and stack them on a new last axis.
+    The output observation has shape `original_shape + (k, )`.
     """
 
     def __init__(self, env, k):
         gym.Wrapper.__init__(self, env)
         self.k = k
         self.frames = deque([], maxlen=k)
-        shp = env.observation_space.shape
-        chan = 1 if len(shp) == 2 else shp[2]
-        self.observation_space = spaces.Box(
-            low=0, high=255, shape=(shp[0], shp[1], chan * k), dtype=np.uint8)
 
     def reset(self):
         """Clear buffer and re-fill by duplicating the first observation."""
@@ -54,10 +49,7 @@ class FrameStack(gym.Wrapper):
     def observation(self):
         assert len(self.frames) == self.k
-        if self.frames[-1].ndim == 2:
-            return np.stack(self.frames, axis=-1)
-        else:
-            return np.concatenate(self.frames, axis=2)
+        return np.stack(self.frames, axis=-1)
 
 class _FireResetEnv(gym.Wrapper):
...
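
Note (not part of the commit): with the rewritten FrameStack, observations are stacked on a new last axis regardless of their rank. A small sketch of the resulting shapes:

import numpy as np

k = 4
gray = [np.zeros((84, 84), dtype=np.uint8) for _ in range(k)]      # ALE-style frames
rgb = [np.zeros((84, 84, 3), dtype=np.uint8) for _ in range(k)]    # gym-style frames
assert np.stack(gray, axis=-1).shape == (84, 84, k)
assert np.stack(rgb, axis=-1).shape == (84, 84, 3, k)
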
 # -*- coding: utf-8 -*-
 # File: common.py
 # Author: Yuxin Wu
 
 import multiprocessing
+import numpy as np
 import random
 import time
 from six.moves import queue
@@ -19,7 +21,8 @@ def play_one_episode(env, func, render=False):
         """
         Map from observation to action, with 0.01 greedy.
         """
-        act = func(s[None, :, :, :])[0][0].argmax()
+        s = np.expand_dims(s, 0)  # batch
+        act = func(s)[0][0].argmax()
         if random.random() < 0.01:
             spc = env.action_space
             act = spc.sample()
...
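
Note (not part of the commit): np.expand_dims adds the batch axis for a state of any rank, which is what lets play_one_episode handle both 2D and 3D observations; a quick check with illustrative shapes:

import numpy as np

for shape in [(84, 84, 4), (84, 84, 3, 4)]:   # ALE-style and gym-style stacked states
    s = np.zeros(shape, dtype=np.uint8)
    assert np.expand_dims(s, 0).shape == (1,) + shape
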
@@ -23,19 +23,21 @@ Experience = namedtuple('Experience',
 class ReplayMemory(object):
     def __init__(self, max_size, state_shape, history_len):
+        """
+        Args:
+            state_shape (tuple[int]): shape (without history) of state
+        """
         self.max_size = int(max_size)
         self.state_shape = state_shape
-        assert len(state_shape) == 3, state_shape
-        # self._state_transpose = list(range(1, len(state_shape) + 1)) + [0]
-        self._channel = state_shape[2] if len(state_shape) == 3 else 1
-        self._shape3d = (state_shape[0], state_shape[1], self._channel * (history_len + 1))
+        assert len(state_shape) in [1, 2, 3], state_shape
+        self._output_shape = self.state_shape + (history_len + 1, )
         self.history_len = int(history_len)
 
-        state_shape = (self.max_size,) + state_shape
+        all_state_shape = (self.max_size,) + state_shape
         logger.info("Creating experience replay buffer of {:.1f} GB ... "
                     "use a smaller buffer if you don't have enough CPU memory.".format(
-                        np.prod(state_shape) / 1024.0**3))
-        self.state = np.zeros(state_shape, dtype='uint8')
+                        np.prod(all_state_shape) / 1024.0**3))
+        self.state = np.zeros(all_state_shape, dtype='uint8')
         self.action = np.zeros((self.max_size,), dtype='int32')
         self.reward = np.zeros((self.max_size,), dtype='float32')
         self.isOver = np.zeros((self.max_size,), dtype='bool')
@@ -70,7 +72,8 @@ class ReplayMemory(object):
     def sample(self, idx):
         """ return a tuple of (s,r,a,o),
-            where s is of shape [H, W, (hist_len+1) * channel]"""
+            where s is of shape self._output_shape, which is
+            [H, W, (hist_len+1) * channel] if input is (H, W, channel)"""
         idx = (self._curr_pos + idx) % self._curr_size
         k = self.history_len + 1
         if idx + k <= self._curr_size:
@@ -95,8 +98,8 @@ class ReplayMemory(object):
                 state = copy.deepcopy(state)
                 state[:k + 1].fill(0)
                 break
-        # move the first dim to the last
-        state = state.transpose(1, 2, 0, 3).reshape(self._shape3d)
+        # move the first dim (history) to the last
+        state = np.moveaxis(state, 0, -1)
         return (state, reward[-2], action[-2], isOver[-2])
 
     def _slice(self, arr, start, end):
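
Note (not part of the commit): for the old (H, W, C) states, np.moveaxis followed by merging channel and history reproduces what transpose(1, 2, 0, 3).reshape(...) used to return, except that channel and history now stay separate axes. A NumPy check with illustrative shapes:

import numpy as np

hist_plus_1, H, W, C = 5, 84, 84, 3
stored = np.random.rand(hist_plus_1, H, W, C)      # frames are stored history-first

new = np.moveaxis(stored, 0, -1)                                       # (H, W, C, hist+1)
old = stored.transpose(1, 2, 0, 3).reshape(H, W, hist_plus_1 * C)      # (H, W, (hist+1) * C)
assert new.shape == (H, W, C, hist_plus_1)
assert np.array_equal(new.transpose(0, 1, 3, 2).reshape(H, W, -1), old)
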
@@ -140,13 +143,13 @@ class ExpReplay(DataFlow, Callback):
             predictor_io_names (tuple of list of str): input/output names to
                 predict Q value from state.
             player (gym.Env): the player.
-            state_shape (tuple): h, w, c
+            state_shape (tuple):
             history_len (int): length of history frames to concat. Zero-filled
                 initial frames.
             update_frequency (int): number of new transitions to add to memory
                 after sampling a batch of transitions for training.
         """
-        assert len(state_shape) == 3, state_shape
+        assert len(state_shape) in [1, 2, 3], state_shape
         init_memory_size = int(init_memory_size)
 
         for k, v in locals().items():
@@ -207,7 +210,7 @@ class ExpReplay(DataFlow, Callback):
             # build a history state
             history = self.mem.recent_state()
             history.append(old_s)
-            history = np.concatenate(history, axis=-1)  # H,W,HistxC
+            history = np.stack(history, axis=-1)  # state_shape + (Hist,)
             history = np.expand_dims(history, axis=0)
             # assume batched network
@@ -216,7 +219,9 @@ class ExpReplay(DataFlow, Callback):
         self._current_ob, reward, isOver, info = self.player.step(act)
         self._current_game_score.feed(reward)
         if isOver:
-            if info['ale.lives'] == 0:  # only record score when a whole game is over (not when an episode is over)
+            # handle ale-specific information
+            if info.get('ale.lives', -1) == 0:
+                # only record score when a whole game is over (not when an episode is over)
                 self._player_scores.feed(self._current_game_score.sum)
             self._current_game_score.reset()
             self.player.reset()
@@ -226,6 +231,7 @@ class ExpReplay(DataFlow, Callback):
     import cv2
 
     def view_state(comb_state):
+        # this function assumes comb_state is 3D
         state = comb_state[:, :, :-1]
        next_state = comb_state[:, :, 1:]
         r = np.concatenate([state[:, :, k] for k in range(self.history_len)], axis=1)
...
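
Note (not part of the commit): the predictor input is now built by stacking the recent frames on a new history axis and then adding a batch axis, instead of concatenating along the channel axis. A shape sketch under the ALE setting (grayscale (84, 84) states):

import numpy as np

state_shape = (84, 84)                  # (84, 84, 3) under gym
history = [np.zeros(state_shape, dtype=np.uint8) for _ in range(4)]
batched = np.expand_dims(np.stack(history, axis=-1), axis=0)
assert batched.shape == (1,) + state_shape + (4,)
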
@@ -44,10 +44,16 @@ class Callback(object):
     _chief_only = True
 
+    name_scope = ""
+    """
+    A name scope for ops created inside this callback.
+    By default to the name of the class, but can be set per-instance.
+    """
+
     def setup_graph(self, trainer):
         self.trainer = trainer
         self.graph = tf.get_default_graph()
-        scope_name = type(self).__name__
+        scope_name = self.name_scope or type(self).__name__
         scope_name = scope_name.replace('_', '')
         with tf.name_scope(scope_name):
             self._setup_graph()
...
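
Note (not part of the commit): a minimal sketch of the scope resolution added above; DummyCallback is a stand-in class, not a tensorpack one. An instance-level name_scope wins, otherwise the class name is used.

class DummyCallback:
    name_scope = ""

cb = DummyCallback()
assert (cb.name_scope or type(cb).__name__) == "DummyCallback"
cb.name_scope = "InputSource/EMA"
assert (cb.name_scope or type(cb).__name__) == "InputSource/EMA"
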
@@ -251,11 +251,13 @@ class QueueInput(FeedfreeInput):
             # in TF there is no API to get queue capacity, so we can only summary the size
             size = tf.cast(self.queue.size(), tf.float32, name='queue_size')
             size_ema_op = add_moving_summary(size, collection=None, decay=0.5)[0].op
-            return RunOp(
+            ret = RunOp(
                 lambda: size_ema_op,
                 run_before=False,
                 run_as_trigger=False,
                 run_step=True)
+            ret.name_scope = "InputSource/EMA"
+            return ret
 
     def _get_callbacks(self):
         from ..callbacks.concurrency import StartProcOrThread
...
@@ -194,6 +194,7 @@ class SyncMultiGPUTrainerReplicated(SingleCostTrainer):
             run_before=True,
             run_as_trigger=self.BROADCAST_EVERY_EPOCH,
             verbose=True)
+        cb.name_scope = "SyncVariables"
         return [cb]
...