Commit 0870401c authored by Yuxin Wu

speedup DQN.

parent cab0c4f3
@@ -20,7 +20,6 @@ from collections import deque
 from tensorpack import *
 from tensorpack.utils.concurrency import *
 from tensorpack.tfutils import symbolic_functions as symbf
-from tensorpack.tfutils.summary import add_moving_summary
 from tensorpack.RL import *
 import common
@@ -34,7 +33,6 @@ FRAME_HISTORY = 4
 ACTION_REPEAT = 4
 CHANNEL = FRAME_HISTORY
-IMAGE_SHAPE3 = IMAGE_SIZE + (CHANNEL,)
 GAMMA = 0.99
 INIT_EXPLORATION = 1
@@ -59,6 +57,7 @@ def get_player(viz=False, train=False):
     global NUM_ACTIONS
     NUM_ACTIONS = pl.get_action_space().num_actions()
     if not train:
+        pl = MapPlayerState(pl, lambda im: im[:, :, np.newaxis])
         pl = HistoryFramePlayer(pl, FRAME_HISTORY)
         pl = PreventStuckPlayer(pl, 30, 1)
         pl = LimitLengthPlayer(pl, 30000)
@@ -73,10 +72,11 @@ class Model(ModelDesc):
         if NUM_ACTIONS is None:
             p = get_player()
             del p
-        return [InputDesc(tf.float32, (None,) + IMAGE_SHAPE3, 'state'),
+        return [InputDesc(tf.uint8,
+                          (None,) + IMAGE_SIZE + (CHANNEL + 1,),
+                          'comb_state'),
                 InputDesc(tf.int64, (None,), 'action'),
                 InputDesc(tf.float32, (None,), 'reward'),
-                InputDesc(tf.float32, (None,) + IMAGE_SHAPE3, 'next_state'),
                 InputDesc(tf.bool, (None,), 'isOver')]

     def _get_DQN_prediction(self, image):
@@ -108,13 +108,20 @@ class Model(ModelDesc):
         return tf.identity(Q, name='Qvalue')

     def _build_graph(self, inputs):
-        state, action, reward, next_state, isOver = inputs
+        ctx = get_current_tower_context()
+        comb_state, action, reward, isOver = inputs
+        comb_state = tf.cast(comb_state, tf.float32)
+        state = tf.slice(comb_state, [0, 0, 0, 0], [-1, -1, -1, 4], name='state')
         self.predict_value = self._get_DQN_prediction(state)
+        if not ctx.is_training:
+            return
+        next_state = tf.slice(comb_state, [0, 0, 0, 1], [-1, -1, -1, 4], name='next_state')
         action_onehot = tf.one_hot(action, NUM_ACTIONS, 1.0, 0.0)

         pred_action_value = tf.reduce_sum(self.predict_value * action_onehot, 1)  # N,
         max_pred_reward = tf.reduce_mean(tf.reduce_max(
             self.predict_value, 1), name='predict_reward')
-        add_moving_summary(max_pred_reward)
+        summary.add_moving_summary(max_pred_reward)

         with tf.variable_scope('target'):
             targetQ_predict_value = self._get_DQN_prediction(next_state)    # NxA
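Note on the `comb_state` change above: instead of feeding `state` and `next_state` as two separate float32 tensors, the model now takes one uint8 tensor with `CHANNEL + 1` (i.e. `FRAME_HISTORY + 1`) channels and slices both views out of it inside the graph, so the overlapping frames are only stored and transferred once. A minimal numpy sketch of the idea, assuming `IMAGE_SIZE = (84, 84)` and `FRAME_HISTORY = 4` (illustrative values, not shown in this diff):

```python
import numpy as np

H, W, HIST = 84, 84, 4   # assumed IMAGE_SIZE and FRAME_HISTORY
comb = np.random.randint(0, 256, size=(H, W, HIST + 1), dtype=np.uint8)

state = comb[:, :, :HIST]       # channels 0..3, what tf.slice(..., [0,0,0,0], [-1,-1,-1,4]) extracts
next_state = comb[:, :, 1:]     # channels 1..4, what tf.slice(..., [0,0,0,1], [-1,-1,-1,4]) extracts

assert state.shape == next_state.shape == (H, W, HIST)
# the two views share HIST - 1 frames, so only one extra frame per transition is kept
```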
@@ -137,7 +144,7 @@ class Model(ModelDesc):
             target - pred_action_value), name='cost')
         summary.add_param_summary(('conv.*/W', ['histogram', 'rms']),
                                   ('fc.*/W', ['histogram', 'rms']))   # monitor all W
-        add_moving_summary(self.cost)
+        summary.add_moving_summary(self.cost)

     def update_target_param(self):
         vars = tf.trainable_variables()
@@ -164,6 +171,7 @@ def get_config():
     expreplay = ExpReplay(
         predictor_io_names=(['state'], ['Qvalue']),
         player=get_player(train=True),
+        state_shape=IMAGE_SIZE,
         batch_size=BATCH_SIZE,
         memory_size=MEMORY_SIZE,
         init_memory_size=INIT_MEMORY_SIZE,
@@ -171,8 +179,9 @@ def get_config():
         end_exploration=END_EXPLORATION,
         exploration_epoch_anneal=EXPLORATION_EPOCH_ANNEAL,
         update_frequency=4,
-        reward_clip=(-1, 1),
-        history_len=FRAME_HISTORY)
+        history_len=FRAME_HISTORY,
+        reward_clip=(-1, 1)
+    )

     return TrainConfig(
         dataflow=expreplay,
@@ -215,7 +224,7 @@ if __name__ == '__main__':
     if args.task != 'train':
         cfg = PredictConfig(
             model=Model(),
-            session_init=SaverRestore(args.load),
+            session_init=get_model_loader(args.load),
             input_names=['state'],
             output_names=['Qvalue'])
         if args.task == 'play':
...
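Rough arithmetic on what the uint8 `comb_state` saves per sample handed to TensorFlow (84x84 frames assumed as above; the numbers are illustrative, not measurements from this commit):

```python
# old inputs: 'state' and 'next_state', each a (84, 84, 4) float32 tensor
old_bytes = 2 * 84 * 84 * 4 * 4        # 225,792 bytes per transition
# new input: one combined (84, 84, 5) uint8 stack, cast to float32 inside the graph
new_bytes = 84 * 84 * 5 * 1            # 35,280 bytes per transition
print(old_bytes / new_bytes)           # ~6.4x less data to assemble and feed per sample
```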
@@ -106,7 +106,7 @@ class AtariPlayer(RLEnvironment):
     def current_state(self):
         """
-        :returns: a gray-scale (h, w, 1) uint8 image
+        :returns: a gray-scale (h, w) uint8 image
         """
         ret = self._grab_raw_image()
         # max-pooled over the last screen
@@ -119,7 +119,6 @@ class AtariPlayer(RLEnvironment):
         # 0.299,0.587.0.114. same as rgb2y in torch/image
         ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
         ret = cv2.resize(ret, self.image_shape)
-        ret = np.expand_dims(ret, axis=2)
         return ret.astype('uint8')  # to save some memory

     def get_action_space(self):
...
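Since `current_state()` now returns a plain `(h, w)` frame, the replay memory can store 2-D frames directly, while the evaluation path in the DQN script restores the channel axis with `MapPlayerState(pl, lambda im: im[:, :, np.newaxis])` before `HistoryFramePlayer` stacks frames. A tiny shape-bookkeeping sketch (illustrative only):

```python
import numpy as np

frame = np.zeros((84, 84), dtype=np.uint8)       # what AtariPlayer.current_state() now returns
frame3 = frame[:, :, np.newaxis]                 # what the MapPlayerState lambda produces: (84, 84, 1)
stacked = np.concatenate([frame3] * 4, axis=2)   # roughly what a 4-frame history player hands to the model
assert stacked.shape == (84, 84, 4)
```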
@@ -90,7 +90,7 @@ def eval_with_funcs(predict_funcs, nr_eval):

 def eval_model_multithread(cfg, nr_eval):
-    func = get_predict_func(cfg)
+    func = OfflinePredictor(cfg)
     NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
     mean, max = eval_with_funcs([func] * NR_PROC, nr_eval)
     logger.info("Average Score: {}; Max Score: {}".format(mean, max))
...
@@ -4,10 +4,11 @@
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

 import numpy as np
+import copy
 from collections import deque, namedtuple
 import threading
 import six
-from six.moves import queue
+from six.moves import queue, range

 from tensorpack.dataflow import DataFlow
 from tensorpack.utils import logger, get_tqdm, get_rng
@@ -20,6 +21,89 @@ Experience = namedtuple('Experience',
                         ['state', 'action', 'reward', 'isOver'])

+
+class ReplayMemory(object):
+    def __init__(self, max_size, state_shape, history_len):
+        self.max_size = int(max_size)
+        self.state_shape = state_shape
+        self.history_len = int(history_len)
+
+        self.state = np.zeros((self.max_size,) + state_shape, dtype='uint8')
+        self.action = np.zeros((self.max_size,), dtype='int32')
+        self.reward = np.zeros((self.max_size,), dtype='float32')
+        self.isOver = np.zeros((self.max_size,), dtype='bool')
+
+        self._curr_size = 0
+        self._curr_pos = 0
+        self._hist = deque(maxlen=history_len - 1)
+
+    def append(self, exp):
+        """
+        Args:
+            exp (Experience):
+        """
+        if self._curr_size < self.max_size:
+            self._assign(self._curr_pos, exp)
+            self._curr_pos = (self._curr_pos + 1) % self.max_size
+            self._curr_size += 1
+        else:
+            self._assign(self._curr_pos, exp)
+            self._curr_pos = (self._curr_pos + 1) % self.max_size
+        if exp.isOver:
+            self._hist.clear()
+        else:
+            self._hist.append(exp)
+
+    def recent_state(self):
+        """ return a list of (hist_len-1,) + STATE_SIZE """
+        lst = list(self._hist)
+        states = [np.zeros(self.state_shape, dtype='uint8')] * (self._hist.maxlen - len(lst))
+        states.extend([k.state for k in lst])
+        return states
+
+    def sample(self, idx):
+        """ return a tuple of (s,r,a,o),
+            where s is of shape STATE_SIZE + (hist_len+1,)"""
+        idx = (self._curr_pos + idx) % self._curr_size
+        k = self.history_len + 1
+        if idx + k <= self._curr_size:
+            state = self.state[idx: idx + k]
+            reward = self.reward[idx: idx + k]
+            action = self.action[idx: idx + k]
+            isOver = self.isOver[idx: idx + k]
+        else:
+            end = idx + k - self._curr_size
+            state = self._slice(self.state, idx, end)
+            reward = self._slice(self.reward, idx, end)
+            action = self._slice(self.action, idx, end)
+            isOver = self._slice(self.isOver, idx, end)
+        ret = self._pad_sample(state, reward, action, isOver)
+        return ret
+
+    # the next_state is a different episode if current_state.isOver==True
+    def _pad_sample(self, state, reward, action, isOver):
+        for k in range(self.history_len - 2, -1, -1):
+            if isOver[k]:
+                state = copy.deepcopy(state)
+                state[:k + 1].fill(0)
+                break
+        state = state.transpose(1, 2, 0)
+        return (state, reward[-2], action[-2], isOver[-2])
+
+    def _slice(self, arr, start, end):
+        s1 = arr[start:]
+        s2 = arr[:end]
+        return np.concatenate((s1, s2), axis=0)
+
+    def __len__(self):
+        return self._curr_size
+
+    def _assign(self, pos, exp):
+        self.state[pos] = exp.state
+        self.reward[pos] = exp.reward
+        self.action[pos] = exp.action
+        self.isOver[pos] = exp.isOver
+
+
 class ExpReplay(DataFlow, Callback):
     """
     Implement experience replay in the paper
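Quick note on the `ReplayMemory` added above: it replaces the old `deque` of `Experience` tuples with preallocated flat numpy arrays used as a ring buffer; `sample(idx)` returns `history_len + 1` consecutive frames transposed to `(h, w, history_len + 1)`, from which the graph slices both the state and the next state. A minimal usage sketch, assuming `ReplayMemory` and `Experience` are importable as defined above (sizes are illustrative):

```python
import numpy as np
# assumes ReplayMemory and Experience from this module; values below are made up for the sketch

mem = ReplayMemory(max_size=1000, state_shape=(84, 84), history_len=4)

for t in range(100):
    frame = np.random.randint(0, 256, (84, 84), dtype='uint8')
    mem.append(Experience(frame, action=0, reward=0.0, isOver=(t % 50 == 49)))

s, r, a, o = mem.sample(10)      # note the (state, reward, action, isOver) order
assert s.shape == (84, 84, 5)    # history_len + 1 channels; s[..., :4] is the state, s[..., 1:] the next state
```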
@@ -36,16 +120,12 @@ class ExpReplay(DataFlow, Callback):
     def __init__(self,
                  predictor_io_names,
                  player,
-                 batch_size=32,
-                 memory_size=1e6,
-                 init_memory_size=50000,
-                 exploration=1,
-                 end_exploration=0.1,
-                 exploration_epoch_anneal=0.002,
-                 reward_clip=None,
-                 update_frequency=1,
-                 history_len=1
-                 ):
+                 state_shape,
+                 batch_size,
+                 memory_size, init_memory_size,
+                 exploration, end_exploration, exploration_epoch_anneal,
+                 update_frequency, history_len,
+                 reward_clip=None):
         """
         Args:
             predictor_io_names (tuple of list of str): input/output names to
@@ -63,7 +143,7 @@ class ExpReplay(DataFlow, Callback):
             setattr(self, k, v)
         self.num_actions = player.get_action_space().num_actions()
         logger.info("Number of Legal actions: {}".format(self.num_actions))
-        self.mem = deque(maxlen=int(memory_size))
         self.rng = get_rng(self)
         self._init_memory_flag = threading.Event()  # tell if memory has been initialized
@@ -71,8 +151,10 @@ class ExpReplay(DataFlow, Callback):
         # a queue to receive notifications to populate memory
         self._populate_job_queue = queue.Queue(maxsize=5)
+        self.mem = ReplayMemory(memory_size, state_shape, history_len)

     def get_simulator_thread(self):
-        # spawn a separate thread to run policy, can speed up 1.3x
+        # spawn a separate thread to run policy
         def populate_job_func():
             self._populate_job_queue.get()
             for _ in range(self.update_frequency):
@@ -84,10 +166,6 @@ class ExpReplay(DataFlow, Callback):

     def _init_memory(self):
         logger.info("Populating replay memory with epsilon={} ...".format(self.exploration))

-        # fill some for the history
-        for k in range(self.history_len):
-            self._populate_exp()
         with get_tqdm(total=self.init_memory_size) as pbar:
             while len(self.mem) < self.init_memory_size:
                 self._populate_exp()
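The explicit "fill some for the history" warm-up loop is gone because `ReplayMemory.recent_state()` already pads a fresh (or just-reset) history with zero frames. A small sketch of that padding behaviour, under the same assumed sizes:

```python
import numpy as np

state_shape, history_len = (84, 84), 4
seen = [np.random.randint(0, 256, state_shape, dtype='uint8')]   # only one frame appended since reset
pad = [np.zeros(state_shape, dtype='uint8')] * (history_len - 1 - len(seen))
recent = pad + seen        # what recent_state() would return: history_len - 1 frames, zeros first
assert len(recent) == history_len - 1
```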
@@ -96,108 +174,69 @@ class ExpReplay(DataFlow, Callback):

     def _populate_exp(self):
         """ populate a transition by epsilon-greedy"""
-        # if len(self.mem):
+        # if len(self.mem) > 4 and not self._init_memory_flag.is_set():
         #     from copy import deepcopy  # quickly fill the memory for debug
-        #     self.mem.append(deepcopy(self.mem[0]))
+        #     self.mem.append(deepcopy(self.mem._hist[0]))
         #     return
         old_s = self.player.current_state()
         if self.rng.rand() <= self.exploration or len(self.mem) < 5:
             act = self.rng.choice(range(self.num_actions))
         else:
             # build a history state
-            # assume a state can be representated by one tensor
-            ss = [old_s]
-            isOver = False
-            for k in range(1, self.history_len):
-                hist_exp = self.mem[-k]
-                if hist_exp.isOver:
-                    isOver = True
-                if isOver:
-                    # fill the beginning of an episode with zeros
-                    ss.append(np.zeros_like(ss[0]))
-                else:
-                    ss.append(hist_exp.state)
-            ss.reverse()
-            ss = np.concatenate(ss, axis=2)
+            history = self.mem.recent_state()
+            history.append(old_s)
+            history = np.stack(history, axis=2)
             # assume batched network
-            q_values = self.predictor([[ss]])[0][0]
+            q_values = self.predictor([[history]])[0][0]
             act = np.argmax(q_values)
         reward, isOver = self.player.action(act)
         if self.reward_clip:
             reward = np.clip(reward, self.reward_clip[0], self.reward_clip[1])
         self.mem.append(Experience(old_s, act, reward, isOver))

+    def debug_sample(self, sample):
+        import cv2
+
+        def view_state(comb_state):
+            state = comb_state[:, :, :-1]
+            next_state = comb_state[:, :, 1:]
+            r = np.concatenate([state[:, :, k] for k in range(self.history_len)], axis=1)
+            r2 = np.concatenate([next_state[:, :, k] for k in range(self.history_len)], axis=1)
+            r = np.concatenate([r, r2], axis=0)
+            cv2.imshow("state", r)
+            cv2.waitKey()
+        print("Act: ", sample[2], " reward:", sample[1], " isOver: ", sample[3])
+        if sample[1] or sample[3]:
+            view_state(sample[0])
+
     def get_data(self):
         # wait for memory to be initialized
         self._init_memory_flag.wait()
         while True:
-            batch_exp = [self._sample_one() for _ in range(self.batch_size)]
-
-            # import cv2  # for debug
-            # def view_state(state, next_state):
-            #     """ for debugging state representation"""
-            #     r = np.concatenate([state[:,:,k] for k in range(self.history_len)], axis=1)
-            #     r2 = np.concatenate([next_state[:,:,k] for k in range(self.history_len)], axis=1)
-            #     r = np.concatenate([r, r2], axis=0)
-            #     print r.shape
-            #     cv2.imshow("state", r)
-            #     cv2.waitKey()
-            # exp = batch_exp[0]
-            # print("Act: ", exp[3], " reward:", exp[2], " isOver: ", exp[4])
-            # if exp[2] or exp[4]:
-            #     view_state(exp[0], exp[1])
-
+            idx = self.rng.randint(
+                self._populate_job_queue.maxsize * self.update_frequency,
+                len(self.mem) - self.history_len - 1,
+                size=self.batch_size)
+            batch_exp = [self.mem.sample(i) for i in idx]
             yield self._process_batch(batch_exp)
             self._populate_job_queue.put(1)

-    # new state is considered useless if isOver==True
-    def _sample_one(self):
-        """ return the transition tuple for
-            [idx, idx+history_len) -> [idx+1, idx+1+history_len)
-            it's the transition from state idx+history_len-1 to state idx+history_len
-        """
-        # look for a state to start with
-        # when x.isOver==True, (x+1).state is of a different episode
-        idx = self.rng.randint(len(self.mem) - self.history_len - 1)
-        samples = [self.mem[k] for k in range(idx, idx + self.history_len + 1)]
-
-        def concat(idx):
-            v = [x.state for x in samples[idx:idx + self.history_len]]
-            return np.concatenate(v, axis=2)
-        state = concat(0)
-        next_state = concat(1)
-        start_mem = samples[-2]
-        reward, action, isOver = start_mem.reward, start_mem.action, start_mem.isOver
-
-        start_idx = self.history_len - 1
-
-        # zero-fill state before starting
-        zero_fill = False
-        for k in range(1, self.history_len):
-            if samples[start_idx - k].isOver:
-                zero_fill = True
-            if zero_fill:
-                state[:, :, -k - 1] = 0
-                if k + 2 <= self.history_len:
-                    next_state[:, :, -k - 2] = 0
-        return (state, next_state, reward, action, isOver)
-
     def _process_batch(self, batch_exp):
-        state = np.asarray([e[0] for e in batch_exp])
-        next_state = np.asarray([e[1] for e in batch_exp])
-        reward = np.asarray([e[2] for e in batch_exp])
-        action = np.asarray([e[3] for e in batch_exp], dtype='int8')
-        isOver = np.asarray([e[4] for e in batch_exp], dtype='bool')
-        return [state, action, reward, next_state, isOver]
+        state = np.asarray([e[0] for e in batch_exp], dtype='uint8')
+        reward = np.asarray([e[1] for e in batch_exp], dtype='float32')
+        action = np.asarray([e[2] for e in batch_exp], dtype='int8')
+        isOver = np.asarray([e[3] for e in batch_exp], dtype='bool')
+        return [state, action, reward, isOver]

     def _setup_graph(self):
         self.predictor = self.trainer.get_predict_func(*self.predictor_io_names)

     def _before_train(self):
         self._init_memory()
+        # TODO start thread here

     def _trigger_epoch(self):
         if self.exploration > self.end_exploration:
...
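Note on the new `get_data()`: instead of concatenating frames per sample in `_sample_one`, it now draws a whole batch of indices into the ring buffer and lets `ReplayMemory.sample()` assemble each `(h, w, history_len + 1)` stack; the lower bound `self._populate_job_queue.maxsize * self.update_frequency` appears to keep the sampled slots away from the region the simulator thread may be overwriting concurrently. The list that `_process_batch` yields matches the new `comb_state/action/reward/isOver` input order. A shape sketch (batch size and image size are illustrative):

```python
import numpy as np

BATCH, H, W, HIST = 64, 84, 84, 4
# each element mimics what ReplayMemory.sample() returns: (state, reward, action, isOver)
batch_exp = [(np.zeros((H, W, HIST + 1), dtype='uint8'), 0.0, 0, False) for _ in range(BATCH)]

state = np.asarray([e[0] for e in batch_exp], dtype='uint8')     # (64, 84, 84, 5) -> 'comb_state'
reward = np.asarray([e[1] for e in batch_exp], dtype='float32')  # (64,)
action = np.asarray([e[2] for e in batch_exp], dtype='int8')     # (64,)
isOver = np.asarray([e[3] for e in batch_exp], dtype='bool')     # (64,)
```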
@@ -116,7 +116,11 @@ class ShareSessionThread(threading.Thread):
     @contextmanager
     def default_sess(self):
-        with self._sess.as_default():
+        if self._sess:
+            with self._sess.as_default():
+                yield
+        else:
+            logger.warn("ShareSessionThread {} wasn't under a default session!".format(self.name))
             yield

     def start(self):
...
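The `default_sess()` change above just guards against the thread being used before a session is attached. The same pattern in isolation, with a dummy session object so the sketch runs on its own (not tensorpack's actual class):

```python
from contextlib import contextmanager

class _DummySession:
    @contextmanager
    def as_default(self):
        yield

class Worker:
    def __init__(self, sess=None):
        self._sess = sess          # may be None if the thread started before any session existed
        self.name = "worker-0"

    @contextmanager
    def default_sess(self):
        if self._sess:             # only enter as_default() when a session is attached
            with self._sess.as_default():
                yield
        else:
            print("warning: {} wasn't under a default session!".format(self.name))
            yield

with Worker(_DummySession()).default_sess():
    pass                           # runs under the session
with Worker().default_sess():
    pass                           # warns, but still yields
```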