Commit 0870401c authored by Yuxin Wu

speedup DQN.

parent cab0c4f3
@@ -20,7 +20,6 @@ from collections import deque
 from tensorpack import *
 from tensorpack.utils.concurrency import *
 from tensorpack.tfutils import symbolic_functions as symbf
-from tensorpack.tfutils.summary import add_moving_summary
 from tensorpack.RL import *
 import common
@@ -34,7 +33,6 @@ FRAME_HISTORY = 4
 ACTION_REPEAT = 4
 CHANNEL = FRAME_HISTORY
-IMAGE_SHAPE3 = IMAGE_SIZE + (CHANNEL,)
 GAMMA = 0.99
 INIT_EXPLORATION = 1
@@ -59,6 +57,7 @@ def get_player(viz=False, train=False):
     global NUM_ACTIONS
     NUM_ACTIONS = pl.get_action_space().num_actions()
     if not train:
+        pl = MapPlayerState(pl, lambda im: im[:, :, np.newaxis])
         pl = HistoryFramePlayer(pl, FRAME_HISTORY)
         pl = PreventStuckPlayer(pl, 30, 1)
     pl = LimitLengthPlayer(pl, 30000)
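The added MapPlayerState wrapper compensates for the AtariPlayer change further down in this commit: the environment now emits plain (h, w) frames, so inference-time players need the channel axis back before frame stacking. A minimal numpy sketch of the shape flow, assuming 84x84 observations and FRAME_HISTORY = 4 as in this example:

import numpy as np

frame = np.zeros((84, 84), dtype=np.uint8)            # what AtariPlayer now returns
with_channel = frame[:, :, np.newaxis]                # the MapPlayerState lambda: (84, 84) -> (84, 84, 1)
stacked = np.concatenate([with_channel] * 4, axis=2)  # roughly what HistoryFramePlayer feeds the network
assert stacked.shape == (84, 84, 4)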
@@ -73,10 +72,11 @@ class Model(ModelDesc):
         if NUM_ACTIONS is None:
             p = get_player()
             del p
-        return [InputDesc(tf.float32, (None,) + IMAGE_SHAPE3, 'state'),
+        return [InputDesc(tf.uint8,
+                          (None,) + IMAGE_SIZE + (CHANNEL + 1,),
+                          'comb_state'),
                 InputDesc(tf.int64, (None,), 'action'),
                 InputDesc(tf.float32, (None,), 'reward'),
-                InputDesc(tf.float32, (None,) + IMAGE_SHAPE3, 'next_state'),
                 InputDesc(tf.bool, (None,), 'isOver')]

     def _get_DQN_prediction(self, image):
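Replacing the two float32 inputs ('state' and 'next_state') with a single uint8 'comb_state' is where the per-transition saving comes from. A back-of-the-envelope check, assuming the usual 84x84 Atari observations (IMAGE_SIZE is not shown in this hunk):

H, W, FRAME_HISTORY = 84, 84, 4                            # assumed sizes

old_bytes = 2 * H * W * FRAME_HISTORY * 4                  # float32 state + next_state per sample
new_bytes = H * W * (FRAME_HISTORY + 1) * 1                # one uint8 comb_state per sample
print(old_bytes, new_bytes, old_bytes / float(new_bytes))  # 225792 35280 6.4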
@@ -108,13 +108,20 @@ class Model(ModelDesc):
         return tf.identity(Q, name='Qvalue')

     def _build_graph(self, inputs):
-        state, action, reward, next_state, isOver = inputs
+        ctx = get_current_tower_context()
+        comb_state, action, reward, isOver = inputs
+        comb_state = tf.cast(comb_state, tf.float32)
+        state = tf.slice(comb_state, [0, 0, 0, 0], [-1, -1, -1, 4], name='state')
         self.predict_value = self._get_DQN_prediction(state)
+        if not ctx.is_training:
+            return
+
+        next_state = tf.slice(comb_state, [0, 0, 0, 1], [-1, -1, -1, 4], name='next_state')
         action_onehot = tf.one_hot(action, NUM_ACTIONS, 1.0, 0.0)
         pred_action_value = tf.reduce_sum(self.predict_value * action_onehot, 1)  # N,
         max_pred_reward = tf.reduce_mean(tf.reduce_max(
             self.predict_value, 1), name='predict_reward')
-        add_moving_summary(max_pred_reward)
+        summary.add_moving_summary(max_pred_reward)

         with tf.variable_scope('target'):
             targetQ_predict_value = self._get_DQN_prediction(next_state)    # NxA
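In the graph, the two tf.slice calls recover 'state' and 'next_state' from the packed tensor: the two 4-frame windows overlap in all but one frame, so FRAME_HISTORY + 1 frames are enough to encode both. The early return under `if not ctx.is_training` additionally keeps the target network and the loss out of the inference graph. A small numpy sketch of the same indexing, with shapes assumed rather than taken from the commit:

import numpy as np

FRAME_HISTORY = 4
comb = np.random.randint(0, 256, size=(84, 84, FRAME_HISTORY + 1)).astype(np.uint8)

state = comb[:, :, 0:FRAME_HISTORY]           # like tf.slice(comb_state, [0, 0, 0, 0], [-1, -1, -1, 4])
next_state = comb[:, :, 1:FRAME_HISTORY + 1]  # like tf.slice(comb_state, [0, 0, 0, 1], [-1, -1, -1, 4])

# the two windows share all but one frame
assert np.array_equal(state[:, :, 1:], next_state[:, :, :-1])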
@@ -137,7 +144,7 @@ class Model(ModelDesc):
             target - pred_action_value), name='cost')
         summary.add_param_summary(('conv.*/W', ['histogram', 'rms']),
                                   ('fc.*/W', ['histogram', 'rms']))   # monitor all W
-        add_moving_summary(self.cost)
+        summary.add_moving_summary(self.cost)

     def update_target_param(self):
         vars = tf.trainable_variables()
@@ -164,6 +171,7 @@ def get_config():
     expreplay = ExpReplay(
         predictor_io_names=(['state'], ['Qvalue']),
         player=get_player(train=True),
+        state_shape=IMAGE_SIZE,
         batch_size=BATCH_SIZE,
         memory_size=MEMORY_SIZE,
         init_memory_size=INIT_MEMORY_SIZE,
@@ -171,8 +179,9 @@ def get_config():
         end_exploration=END_EXPLORATION,
         exploration_epoch_anneal=EXPLORATION_EPOCH_ANNEAL,
         update_frequency=4,
-        reward_clip=(-1, 1),
-        history_len=FRAME_HISTORY)
+        history_len=FRAME_HISTORY,
+        reward_clip=(-1, 1)
+    )

     return TrainConfig(
         dataflow=expreplay,
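The new state_shape and history_len arguments, together with the uint8 comb_state input above, suggest that ExpReplay now keeps individual (h, w) frames and assembles the FRAME_HISTORY + 1 stack only when a sample is drawn; the ExpReplay diff itself is not expanded on this page. The following is only an illustrative sketch of that idea, not tensorpack's implementation:

import numpy as np

class FrameHistoryBuffer:
    """Toy ring buffer: store single uint8 frames, stack them on demand.
    Episode boundaries are ignored for brevity."""

    def __init__(self, capacity, state_shape, history_len):
        self.frames = np.zeros((capacity,) + tuple(state_shape), dtype=np.uint8)
        self.capacity = capacity
        self.history_len = history_len
        self.cur = 0

    def append(self, frame):
        self.frames[self.cur % self.capacity] = frame
        self.cur += 1

    def comb_state(self, idx):
        # (h, w, history_len + 1): the 'state' window plus the frame after it
        stack = [self.frames[(idx + k) % self.capacity] for k in range(self.history_len + 1)]
        return np.stack(stack, axis=-1)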
@@ -215,7 +224,7 @@ if __name__ == '__main__':
     if args.task != 'train':
         cfg = PredictConfig(
             model=Model(),
-            session_init=SaverRestore(args.load),
+            session_init=get_model_loader(args.load),
             input_names=['state'],
             output_names=['Qvalue'])
     if args.task == 'play':
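Swapping SaverRestore(args.load) for get_model_loader(args.load) lets --load accept more than a TF checkpoint; as I understand tensorpack's session-init utilities, the loader is chosen from the file that is passed in (the exact supported formats depend on the version). A hedged usage sketch with hypothetical paths:

# Hypothetical --load values; get_model_loader dispatches on the file
# instead of assuming a TF checkpoint.
for load_path in ['train_log/DQN/model-100000',   # TF checkpoint (hypothetical path)
                  'DQN-pretrained.npy']:          # saved parameter dict (hypothetical path)
    cfg = PredictConfig(model=Model(),
                        session_init=get_model_loader(load_path),
                        input_names=['state'],
                        output_names=['Qvalue'])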
@@ -106,7 +106,7 @@ class AtariPlayer(RLEnvironment):
     def current_state(self):
         """
-        :returns: a gray-scale (h, w, 1) uint8 image
+        :returns: a gray-scale (h, w) uint8 image
         """
         ret = self._grab_raw_image()
         # max-pooled over the last screen
@@ -119,7 +119,6 @@ class AtariPlayer(RLEnvironment):
         # 0.299,0.587.0.114. same as rgb2y in torch/image
         ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
         ret = cv2.resize(ret, self.image_shape)
-        ret = np.expand_dims(ret, axis=2)
         return ret.astype('uint8')  # to save some memory

     def get_action_space(self):
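With the np.expand_dims call gone, current_state() returns a plain (h, w) uint8 frame, and the channel axis is added only by the consumers that need it (MapPlayerState at inference, the replay buffer stacking in training). A condensed sketch of the resulting preprocessing, with IMAGE_SIZE assumed to be (84, 84):

import cv2

IMAGE_SIZE = (84, 84)   # assumption; not shown in this hunk

def to_observation(rgb_screen):
    gray = cv2.cvtColor(rgb_screen, cv2.COLOR_RGB2GRAY)   # same 0.299/0.587/0.114 luma weights
    gray = cv2.resize(gray, IMAGE_SIZE)
    return gray.astype('uint8')                           # shape (h, w), one byte per pixel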
@@ -90,7 +90,7 @@ def eval_with_funcs(predict_funcs, nr_eval):

 def eval_model_multithread(cfg, nr_eval):
-    func = get_predict_func(cfg)
+    func = OfflinePredictor(cfg)
     NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
     mean, max = eval_with_funcs([func] * NR_PROC, nr_eval)
     logger.info("Average Score: {}; Max Score: {}".format(mean, max))
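get_predict_func was the older entry point; OfflinePredictor(cfg) builds its own graph and session and returns a callable, which is what eval_with_funcs hands to each worker thread. A hedged usage sketch:

import numpy as np

# Assumes `cfg` is the PredictConfig built in DQN.py and `obs` is one stacked
# observation of shape IMAGE_SIZE + (FRAME_HISTORY,).
func = OfflinePredictor(cfg)              # builds graph + session, returns a callable
q_values, = func(obs[np.newaxis, ...])    # add a batch axis; output shape (1, NUM_ACTIONS)
action = int(q_values[0].argmax())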
(The diff of one more file in this commit is collapsed and not shown here.)
@@ -116,7 +116,11 @@ class ShareSessionThread(threading.Thread):
     @contextmanager
     def default_sess(self):
-        with self._sess.as_default():
+        if self._sess:
+            with self._sess.as_default():
+                yield
+        else:
+            logger.warn("ShareSessionThread {} wasn't under a default session!".format(self.name))
             yield

     def start(self):
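The new guard matters for threads whose default_sess() is entered before any TF session was captured; previously that crashed because self._sess was still None. A hedged sketch of how a worker built on ShareSessionThread typically uses it:

from tensorpack.utils import logger
from tensorpack.utils.concurrency import ShareSessionThread

class EvalWorker(ShareSessionThread):
    """Hypothetical worker: the base class captures the current default session
    in start(); default_sess() re-installs it inside the thread (or, with the
    new guard, only warns if none was captured)."""
    def __init__(self, func):
        super(EvalWorker, self).__init__()
        self.func = func

    def run(self):
        with self.default_sess():
            logger.info("score: {}".format(self.func()))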