Commit 0b561b3b authored by Yuxin Wu

Make DQN support states with more dimensions

parent 87fad54b
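The commit replaces the old convention of flattening history into the channel axis, (H, W, History * Channel), with a trailing history axis, state_shape + (History,), so that 2D (grayscale ALE) and 3D (color gym) observations are handled uniformly. A minimal numpy sketch of the two layouts (illustrative, not part of the commit), assuming IMAGE_SIZE = (84, 84), FRAME_HISTORY = 4 and 3-channel gym frames:

import numpy as np

frames = [np.zeros((84, 84, 3), dtype=np.uint8) for _ in range(4)]

# old layout: history merged into the channel axis -> (84, 84, 12)
old_state = np.concatenate(frames, axis=-1)
assert old_state.shape == (84, 84, 12)

# new layout: observation shape kept, history on a new last axis -> (84, 84, 3, 4)
new_state = np.stack(frames, axis=-1)
assert new_state.shape == (84, 84, 3, 4)

# a grayscale (84, 84) observation now simply becomes (84, 84, 4)
gray = [np.zeros((84, 84), dtype=np.uint8) for _ in range(4)]
assert np.stack(gray, axis=-1).shape == (84, 84, 4)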
......@@ -34,8 +34,7 @@ else:
IMAGE_SIZE = (84, 84)
FRAME_HISTORY = 4
GAMMA = 0.99
CHANNEL = FRAME_HISTORY * 3
IMAGE_SHAPE3 = IMAGE_SIZE + (CHANNEL,)
STATE_SHAPE = IMAGE_SIZE + (3, )
LOCAL_TIME_MAX = 5
STEPS_PER_EPOCH = 6000
......@@ -70,13 +69,17 @@ class MySimulatorWorker(SimulatorProcess):
class Model(ModelDesc):
def inputs(self):
assert NUM_ACTIONS is not None
return [tf.placeholder(tf.uint8, (None,) + IMAGE_SHAPE3, 'state'),
return [tf.placeholder(tf.uint8, (None,) + STATE_SHAPE + (FRAME_HISTORY, ), 'state'),
tf.placeholder(tf.int64, (None,), 'action'),
tf.placeholder(tf.float32, (None,), 'futurereward'),
tf.placeholder(tf.float32, (None,), 'action_prob'),
]
def _get_NN_prediction(self, image):
def _get_NN_prediction(self, state):
assert state.shape.rank == 5 # Batch, H, W, Channel, History
state = tf.transpose(state, [0, 1, 2, 4, 3]) # swap channel & history, to be compatible with old models
image = tf.reshape(state, [-1] + list(STATE_SHAPE[:2]) + [STATE_SHAPE[2] * FRAME_HISTORY])
image = tf.cast(image, tf.float32) / 255.0
with argscope(Conv2D, activation=tf.nn.relu):
l = Conv2D('conv0', image, 32, 5)
......
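The new _get_NN_prediction receives a 5-D state of shape (N, H, W, Channel, History); the transpose swaps channel and history before the reshape so that the flattened tensor matches the (H, W, History * Channel) layout old checkpoints were trained on. A numpy sketch of that equivalence (illustrative only, using the shapes above):

import numpy as np

N, H, W, C, HIST = 2, 84, 84, 3, 4
state = np.random.randint(0, 256, size=(N, H, W, C, HIST), dtype=np.uint8)

# swap channel & history, then merge the last two axes
image = state.transpose(0, 1, 2, 4, 3).reshape(N, H, W, C * HIST)

# identical to concatenating the individual frames along the channel axis (old layout)
old_layout = np.concatenate([state[..., t] for t in range(HIST)], axis=-1)
assert np.array_equal(image, old_layout)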
......@@ -19,7 +19,7 @@ from expreplay import ExpReplay
BATCH_SIZE = 64
IMAGE_SIZE = (84, 84)
IMAGE_CHANNEL = None # 3 in gym and 1 in our own wrapper
STATE_SHAPE = None # IMAGE_SIZE + (3,) in gym, and IMAGE_SIZE in ALE
FRAME_HISTORY = 4
ACTION_REPEAT = 4 # aka FRAME_SKIP
UPDATE_FREQ = 4
......@@ -39,8 +39,7 @@ METHOD = None
def resize_keepdims(im, size):
# Opencv's resize remove the extra dimension for grayscale images.
# We add it back.
# OpenCV's resize removes the extra dimension for grayscale images. We add it back.
ret = cv2.resize(im, size)
if im.ndim == 3 and ret.ndim == 2:
ret = ret[:, :, np.newaxis]
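A small usage sketch of resize_keepdims (hypothetical shapes, not from the diff): cv2.resize drops the trailing singleton channel of a grayscale image, and the helper restores it so callers always see the same rank.

import cv2
import numpy as np

gray = np.zeros((210, 160, 1), dtype=np.uint8)
assert cv2.resize(gray, (84, 84)).shape == (84, 84)          # channel axis dropped by OpenCV
assert resize_keepdims(gray, (84, 84)).shape == (84, 84, 1)  # restored by the helper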
......@@ -65,10 +64,20 @@ def get_player(viz=False, train=False):
class Model(DQNModel):
"""
A DQN model for 2D/3D (image) observations.
"""
def __init__(self):
super(Model, self).__init__(IMAGE_SIZE, IMAGE_CHANNEL, FRAME_HISTORY, METHOD, NUM_ACTIONS, GAMMA)
assert len(STATE_SHAPE) in [2, 3]
super(Model, self).__init__(STATE_SHAPE, FRAME_HISTORY, METHOD, NUM_ACTIONS, GAMMA)
def _get_DQN_prediction(self, image):
assert image.shape.rank in [4, 5], image.shape
# image: N, H, W, (C), Hist
if image.shape.rank == 5:
# merge C & Hist
image = tf.reshape(image, [-1] + list(STATE_SHAPE[:2]) + [STATE_SHAPE[2] * FRAME_HISTORY])
image = image / 255.0
with argscope(Conv2D, activation=lambda x: PReLU('prelu', x), use_bias=True):
l = (LinearWrap(image)
......@@ -102,7 +111,7 @@ def get_config():
expreplay = ExpReplay(
predictor_io_names=(['state'], ['Qvalue']),
player=get_player(train=True),
state_shape=IMAGE_SIZE + (IMAGE_CHANNEL,),
state_shape=STATE_SHAPE,
batch_size=BATCH_SIZE,
memory_size=MEMORY_SIZE,
init_memory_size=INIT_MEMORY_SIZE,
......@@ -152,7 +161,7 @@ if __name__ == '__main__':
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
ENV_NAME = args.env
USE_GYM = not ENV_NAME.endswith('.bin')
IMAGE_CHANNEL = 3 if USE_GYM else 1
STATE_SHAPE = IMAGE_SIZE + (3, ) if USE_GYM else IMAGE_SIZE
METHOD = args.algo
# set num_actions
NUM_ACTIONS = get_player().action_space.n
......
......@@ -12,16 +12,19 @@ from tensorpack.utils import logger
class Model(ModelDesc):
learning_rate = 1e-3
def __init__(self, image_shape, channel, history, method, num_actions, gamma):
assert len(image_shape) == 2, image_shape
state_dtype = tf.uint8
self.channel = channel
self._shape2d = tuple(image_shape)
self._shape3d = self._shape2d + (channel, )
self._shape4d_for_prediction = (-1, ) + self._shape2d + (history * channel, )
self._channel = channel
def __init__(self, state_shape, history, method, num_actions, gamma):
"""
Args:
state_shape (tuple[int]): shape (without history) of the state.
history (int): number of frames to stack as history.
"""
self._state_shape = tuple(state_shape)
self._stacked_state_shape = (-1, ) + self._state_shape + (history, )
self.history = history
self.method = method
self.num_actions = num_actions
......@@ -31,37 +34,43 @@ class Model(ModelDesc):
# When we use h history frames, the current state and the next state will have (h-1) overlapping frames.
# Therefore we use a combined state for efficiency:
# The first h are the current state, and the last h are the next state.
return [tf.placeholder(tf.uint8,
(None,) + self._shape2d +
((self.history + 1) * self.channel,),
return [tf.placeholder(self.state_dtype,
(None,) + self._state_shape + (self.history + 1, ),
'comb_state'),
tf.placeholder(tf.int64, (None,), 'action'),
tf.placeholder(tf.float32, (None,), 'reward'),
tf.placeholder(tf.bool, (None,), 'isOver')]
@abc.abstractmethod
def _get_DQN_prediction(self, image):
def _get_DQN_prediction(self, state):
"""
state: a tensor of shape (N,) + state_shape + (history, )
"""
pass
@auto_reuse_variable_scope
def get_DQN_prediction(self, image):
""" image: [N, H, W, history * C] in [0,255]"""
return self._get_DQN_prediction(image)
def get_DQN_prediction(self, state):
return self._get_DQN_prediction(state)
def build_graph(self, comb_state, action, reward, isOver):
comb_state = tf.cast(comb_state, tf.float32)
comb_state = tf.reshape(
comb_state, [-1] + list(self._shape2d) + [self.history + 1, self.channel])
input_rank = comb_state.shape.rank
state = tf.slice(
comb_state,
[0] * input_rank,
[-1] * (input_rank - 1) + [self.history], name='state')
state = tf.slice(comb_state, [0, 0, 0, 0, 0], [-1, -1, -1, self.history, -1])
state = tf.reshape(state, self._shape4d_for_prediction, name='state')
self.predict_value = self.get_DQN_prediction(state)
if not get_current_tower_context().is_training:
return
reward = tf.clip_by_value(reward, -1, 1)
next_state = tf.slice(comb_state, [0, 0, 0, 1, 0], [-1, -1, -1, self.history, -1], name='next_state')
next_state = tf.reshape(next_state, self._shape4d_for_prediction)
next_state = tf.slice(
comb_state,
[0] * (input_rank - 1) + [1],
[-1] * (input_rank - 1) + [self.history], name='next_state')
next_state = tf.reshape(next_state, self._stacked_state_shape)
action_onehot = tf.one_hot(action, self.num_actions, 1.0, 0.0)
pred_action_value = tf.reduce_sum(self.predict_value * action_onehot, 1) # N,
......
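The build_graph change above slices state and next_state out of comb_state along the last (history) axis in a rank-agnostic way, instead of reshaping to a fixed 5-D layout. A numpy sketch of the equivalent slicing (illustrative, assuming FRAME_HISTORY = 4 and a (84, 84, 3) gym state):

import numpy as np

h = 4
# comb_state: (N,) + state_shape + (h + 1,); the h + 1 frames cover both states
comb_state = np.random.randint(0, 256, size=(2, 84, 84, 3, h + 1), dtype=np.uint8)

state = comb_state[..., :h]        # first h frames along the history axis
next_state = comb_state[..., 1:]   # last h frames
# consecutive states share h - 1 frames, which is why a single combined tensor is fed
assert np.array_equal(state[..., 1:], next_state[..., :-1])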
......@@ -94,7 +94,7 @@ class AtariPlayer(gym.Env):
self.action_space = spaces.Discrete(len(self.actions))
self.observation_space = spaces.Box(
low=0, high=255, shape=(self.height, self.width, 1), dtype=np.uint8)
low=0, high=255, shape=(self.height, self.width), dtype=np.uint8)
self._restart_episode()
def get_action_meanings(self):
......@@ -109,7 +109,7 @@ class AtariPlayer(gym.Env):
def _current_state(self):
"""
:returns: a gray-scale (h, w, 1) uint8 image
:returns: a gray-scale (h, w) uint8 image
"""
ret = self._grab_raw_image()
# max-pooled over the last screen
......@@ -120,7 +120,7 @@ class AtariPlayer(gym.Env):
cv2.waitKey(int(self.viz * 1000))
ret = ret.astype('float32')
# 0.299, 0.587, 0.114. same as rgb2y in torch/image
ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)[:, :, np.newaxis]
ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)[:, :]
return ret.astype('uint8') # to save some memory
def _restart_episode(self):
......
......@@ -4,7 +4,6 @@
import numpy as np
from collections import deque
import gym
from gym import spaces
_v0, _v1 = gym.__version__.split('.')[:2]
assert int(_v0) > 0 or int(_v1) >= 10, gym.__version__
......@@ -27,17 +26,13 @@ class MapState(gym.ObservationWrapper):
class FrameStack(gym.Wrapper):
"""
Buffer observations and stack across channels (last axis).
The output observation has shape (H, W, History * Channel)
Buffer consecutive k observations and stack them on a new last axis.
The output observation has shape `original_shape + (k, )`.
"""
def __init__(self, env, k):
gym.Wrapper.__init__(self, env)
self.k = k
self.frames = deque([], maxlen=k)
shp = env.observation_space.shape
chan = 1 if len(shp) == 2 else shp[2]
self.observation_space = spaces.Box(
low=0, high=255, shape=(shp[0], shp[1], chan * k), dtype=np.uint8)
def reset(self):
"""Clear buffer and re-fill by duplicating the first observation."""
......@@ -54,10 +49,7 @@ class FrameStack(gym.Wrapper):
def observation(self):
assert len(self.frames) == self.k
if self.frames[-1].ndim == 2:
return np.stack(self.frames, axis=-1)
else:
return np.concatenate(self.frames, axis=2)
class _FireResetEnv(gym.Wrapper):
......
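With the changes above, FrameStack no longer reshapes observations into the channel axis; each observation keeps its own shape and gains a trailing history axis. A hypothetical usage sketch with a dummy 84x84 grayscale environment (names and shapes are assumptions for illustration, not part of the commit):

import gym
import numpy as np
from gym import spaces

class _DummyEnv(gym.Env):
    # hypothetical single-frame grayscale environment, only for illustration
    observation_space = spaces.Box(low=0, high=255, shape=(84, 84), dtype=np.uint8)
    action_space = spaces.Discrete(2)

    def reset(self):
        return np.zeros((84, 84), dtype=np.uint8)

    def step(self, action):
        return np.zeros((84, 84), dtype=np.uint8), 0.0, False, {}

env = FrameStack(_DummyEnv(), 4)
assert env.reset().shape == (84, 84, 4)   # (H, W) observations stack to (H, W, k)
# a (H, W, 3) color observation would stack to (H, W, 3, k)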
# -*- coding: utf-8 -*-
# File: common.py
# Author: Yuxin Wu
import multiprocessing
import numpy as np
import random
import time
from six.moves import queue
......@@ -19,7 +21,8 @@ def play_one_episode(env, func, render=False):
"""
Map from observation to action, with 0.01-greedy exploration (a random action with probability 0.01).
"""
act = func(s[None, :, :, :])[0][0].argmax()
s = np.expand_dims(s, 0) # batch
act = func(s)[0][0].argmax()
if random.random() < 0.01:
spc = env.action_space
act = spc.sample()
......
......@@ -23,19 +23,21 @@ Experience = namedtuple('Experience',
class ReplayMemory(object):
def __init__(self, max_size, state_shape, history_len):
"""
Args:
state_shape (tuple[int]): shape (without history) of state
"""
self.max_size = int(max_size)
self.state_shape = state_shape
assert len(state_shape) == 3, state_shape
# self._state_transpose = list(range(1, len(state_shape) + 1)) + [0]
self._channel = state_shape[2] if len(state_shape) == 3 else 1
self._shape3d = (state_shape[0], state_shape[1], self._channel * (history_len + 1))
assert len(state_shape) in [1, 2, 3], state_shape
self._output_shape = self.state_shape + (history_len + 1, )
self.history_len = int(history_len)
state_shape = (self.max_size,) + state_shape
all_state_shape = (self.max_size,) + state_shape
logger.info("Creating experience replay buffer of {:.1f} GB ... "
"use a smaller buffer if you don't have enough CPU memory.".format(
np.prod(state_shape) / 1024.0**3))
self.state = np.zeros(state_shape, dtype='uint8')
np.prod(all_state_shape) / 1024.0**3))
self.state = np.zeros(all_state_shape, dtype='uint8')
self.action = np.zeros((self.max_size,), dtype='int32')
self.reward = np.zeros((self.max_size,), dtype='float32')
self.isOver = np.zeros((self.max_size,), dtype='bool')
......@@ -70,7 +72,8 @@ class ReplayMemory(object):
def sample(self, idx):
""" return a tuple of (s,r,a,o),
where s is of shape [H, W, (hist_len+1) * channel]"""
where s is of shape self._output_shape, i.e. state_shape + (hist_len + 1, ),
e.g. [H, W, channel, hist_len + 1] if the input state is (H, W, channel)"""
idx = (self._curr_pos + idx) % self._curr_size
k = self.history_len + 1
if idx + k <= self._curr_size:
......@@ -95,8 +98,8 @@ class ReplayMemory(object):
state = copy.deepcopy(state)
state[:k + 1].fill(0)
break
# move the first dim to the last
state = state.transpose(1, 2, 0, 3).reshape(self._shape3d)
# move the first dim (history) to the last
state = np.moveaxis(state, 0, -1)
return (state, reward[-2], action[-2], isOver[-2])
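The sampled frames come out of the buffer with history on the first axis; np.moveaxis replaces the old transpose-and-reshape and simply moves that axis to the end, giving state_shape + (hist_len + 1,) for 1D, 2D, or 3D states alike. A small numpy sketch (illustrative shapes):

import numpy as np

hist_len = 4
sampled = np.zeros((hist_len + 1, 84, 84, 3), dtype=np.uint8)   # (hist_len + 1,) + state_shape
assert np.moveaxis(sampled, 0, -1).shape == (84, 84, 3, hist_len + 1)

# a grayscale (84, 84) state gives (84, 84, hist_len + 1) with the same call
assert np.moveaxis(np.zeros((hist_len + 1, 84, 84)), 0, -1).shape == (84, 84, hist_len + 1)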
def _slice(self, arr, start, end):
......@@ -140,13 +143,13 @@ class ExpReplay(DataFlow, Callback):
predictor_io_names (tuple of list of str): input/output names to
predict Q value from state.
player (gym.Env): the player.
state_shape (tuple): h, w, c
state_shape (tuple): shape of a single observation, without history.
history_len (int): length of history frames to concat. Zero-filled
initial frames.
update_frequency (int): number of new transitions to add to memory
after sampling a batch of transitions for training.
"""
assert len(state_shape) == 3, state_shape
assert len(state_shape) in [1, 2, 3], state_shape
init_memory_size = int(init_memory_size)
for k, v in locals().items():
......@@ -207,7 +210,7 @@ class ExpReplay(DataFlow, Callback):
# build a history state
history = self.mem.recent_state()
history.append(old_s)
history = np.concatenate(history, axis=-1) # H,W,HistxC
history = np.stack(history, axis=-1) # state_shape + (Hist,)
history = np.expand_dims(history, axis=0)
# assume batched network
......@@ -216,7 +219,9 @@ class ExpReplay(DataFlow, Callback):
self._current_ob, reward, isOver, info = self.player.step(act)
self._current_game_score.feed(reward)
if isOver:
if info['ale.lives'] == 0: # only record score when a whole game is over (not when an episode is over)
# handle ale-specific information
if info.get('ale.lives', -1) == 0:
# only record score when a whole game is over (not when an episode is over)
self._player_scores.feed(self._current_game_score.sum)
self._current_game_score.reset()
self.player.reset()
......@@ -226,6 +231,7 @@ class ExpReplay(DataFlow, Callback):
import cv2
def view_state(comb_state):
# this function assumes comb_state is 3D
state = comb_state[:, :, :-1]
next_state = comb_state[:, :, 1:]
r = np.concatenate([state[:, :, k] for k in range(self.history_len)], axis=1)
......
......@@ -44,10 +44,16 @@ class Callback(object):
_chief_only = True
name_scope = ""
"""
A name scope for ops created inside this callback.
It defaults to the name of the class, but can be set per-instance.
"""
def setup_graph(self, trainer):
self.trainer = trainer
self.graph = tf.get_default_graph()
scope_name = type(self).__name__
scope_name = self.name_scope or type(self).__name__
scope_name = scope_name.replace('_', '')
with tf.name_scope(scope_name):
self._setup_graph()
......
......@@ -251,11 +251,13 @@ class QueueInput(FeedfreeInput):
# in TF there is no API to get queue capacity, so we can only summarize the size
size = tf.cast(self.queue.size(), tf.float32, name='queue_size')
size_ema_op = add_moving_summary(size, collection=None, decay=0.5)[0].op
return RunOp(
ret = RunOp(
lambda: size_ema_op,
run_before=False,
run_as_trigger=False,
run_step=True)
ret.name_scope = "InputSource/EMA"
return ret
def _get_callbacks(self):
from ..callbacks.concurrency import StartProcOrThread
......
......@@ -194,6 +194,7 @@ class SyncMultiGPUTrainerReplicated(SingleCostTrainer):
run_before=True,
run_as_trigger=self.BROADCAST_EVERY_EPOCH,
verbose=True)
cb.name_scope = "SyncVariables"
return [cb]
......