Commit d3167ba3 authored by Yuxin Wu

add dueling

parent 4f3d4e27
@@ -23,13 +23,13 @@ import common
from common import play_model, Evaluator, eval_model_multithread
from atari import AtariPlayer
METHOD = ['DQN', 'Double', 'Dueling'][1]
BATCH_SIZE = 64
IMAGE_SIZE = (84, 84)
FRAME_HISTORY = 4
ACTION_REPEAT = 4
HEIGHT_RANGE = (None, None)
#HEIGHT_RANGE = (36, 204) # for breakout
#HEIGHT_RANGE = (28, -8) # for pong
CHANNEL = FRAME_HISTORY
IMAGE_SHAPE3 = IMAGE_SIZE + (CHANNEL,)
@@ -50,9 +50,8 @@ NUM_ACTIONS = None
ROM_FILE = None
def get_player(viz=False, train=False):
pl = AtariPlayer(ROM_FILE, height_range=HEIGHT_RANGE,
frame_skip=ACTION_REPEAT, image_shape=IMAGE_SIZE[::-1], viz=viz,
live_lost_as_eoe=train)
pl = AtariPlayer(ROM_FILE, frame_skip=ACTION_REPEAT,
image_shape=IMAGE_SIZE[::-1], viz=viz, live_lost_as_eoe=train)
global NUM_ACTIONS
NUM_ACTIONS = pl.get_action_space().num_actions()
if not train:
@@ -76,7 +75,7 @@ class Model(ModelDesc):
""" image: [0,255]"""
image = image / 255.0
with argscope(Conv2D, nl=PReLU.f, use_bias=True):
return (LinearWrap(image)
l = (LinearWrap(image)
.Conv2D('conv0', out_channel=32, kernel_shape=5)
.MaxPooling('pool0', 2)
.Conv2D('conv1', out_channel=32, kernel_shape=5)
@@ -90,8 +89,14 @@
#.Conv2D('conv1', out_channel=64, kernel_shape=4, stride=2)
#.Conv2D('conv2', out_channel=64, kernel_shape=3)
.FullyConnected('fc0', 512, nl=lambda x, name: LeakyReLU.f(x, 0.01, name))
.FullyConnected('fct', NUM_ACTIONS, nl=tf.identity)())
.FullyConnected('fc0', 512, nl=lambda x, name: LeakyReLU.f(x, 0.01, name))())
if METHOD != 'Dueling':
Q = FullyConnected('fct', l, NUM_ACTIONS, nl=tf.identity)
else:
V = FullyConnected('fctV', l, 1, nl=tf.identity)
As = FullyConnected('fctA', l, NUM_ACTIONS, nl=tf.identity)
Q = tf.add(As, V - tf.reduce_mean(As, 1, keep_dims=True))
return tf.identity(Q, name='Qvalue')
def _build_graph(self, inputs):
state, action, reward, next_state, isOver = inputs
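The dueling head added above splits the shared `fc0` features into a scalar state-value stream (`fctV`) and a per-action advantage stream (`fctA`), then recombines them with the mean-subtracted aggregation from the Dueling-DQN paper so that the two streams stay identifiable. A minimal NumPy sketch of that aggregation, with made-up batch values that are not part of the commit:

```python
import numpy as np

# Hypothetical dueling-head outputs for a batch of 2 states and 3 actions.
V = np.array([[0.5], [-1.0]])                      # state values, shape (N, 1)
A = np.array([[1.0, 2.0, 3.0], [0.0, 0.0, 3.0]])   # advantages, shape (N, NUM_ACTIONS)

# Same combination as `tf.add(As, V - tf.reduce_mean(As, 1, keep_dims=True))`:
#   Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)
Q = A + (V - A.mean(axis=1, keepdims=True))
print(Q)   # first row: 0.5 + [1, 2, 3] - 2 -> [-0.5, 0.5, 1.5]
```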
@@ -105,22 +110,22 @@
with tf.variable_scope('target'):
targetQ_predict_value = self._get_DQN_prediction(next_state) # NxA
# DQN
#best_v = tf.reduce_max(targetQ_predict_value, 1) # N,
# Double-DQN
tf.get_variable_scope().reuse_variables()
next_predict_value = self._get_DQN_prediction(next_state)
self.greedy_choice = tf.argmax(next_predict_value, 1) # N,
predict_onehot = tf.one_hot(self.greedy_choice, NUM_ACTIONS, 1.0, 0.0)
best_v = tf.reduce_sum(targetQ_predict_value * predict_onehot, 1)
if METHOD != 'Double':
# DQN
best_v = tf.reduce_max(targetQ_predict_value, 1) # N,
else:
# Double-DQN
tf.get_variable_scope().reuse_variables()
next_predict_value = self._get_DQN_prediction(next_state)
self.greedy_choice = tf.argmax(next_predict_value, 1) # N,
predict_onehot = tf.one_hot(self.greedy_choice, NUM_ACTIONS, 1.0, 0.0)
best_v = tf.reduce_sum(targetQ_predict_value * predict_onehot, 1)
target = reward + (1.0 - tf.cast(isOver, tf.float32)) * GAMMA * tf.stop_gradient(best_v)
cost = symbf.huber_loss(target - pred_action_value)
self.cost = symbf.huber_loss(target - pred_action_value, name='cost')
summary.add_param_summary([('conv.*/W', ['histogram', 'rms']),
('fc.*/W', ['histogram', 'rms']) ]) # monitor all W
self.cost = tf.reduce_mean(cost, name='cost')
def update_target_param(self):
vars = tf.trainable_variables()
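In the branch above, vanilla DQN bootstraps from the target network's own maximum, while Double-DQN lets the online network choose the greedy action and the target network evaluate it; the Q-learning target is then `reward + GAMMA * best_v` on non-terminal transitions, with the gradient stopped through `best_v`. A NumPy sketch of both targets on toy numbers (all values here are illustrative, not taken from the code):

```python
import numpy as np

GAMMA = 0.99
reward = np.array([1.0, 0.0])
isOver = np.array([0.0, 1.0])                      # 1.0 marks end of episode
targetQ = np.array([[1.0, 5.0], [2.0, 0.0]])       # target network output, N x A
onlineQ = np.array([[4.0, 3.0], [0.0, 1.0]])       # online network output, N x A

# Vanilla DQN: max over the target network's own estimates.
best_v_dqn = targetQ.max(axis=1)                                       # [5., 2.]

# Double-DQN: online network picks the action, target network scores it.
greedy_choice = onlineQ.argmax(axis=1)                                 # [0, 1]
best_v_double = targetQ[np.arange(len(greedy_choice)), greedy_choice]  # [1., 0.]

# Bootstrapped target; in the graph, tf.stop_gradient keeps it fixed w.r.t. the loss.
target = reward + (1.0 - isOver) * GAMMA * best_v_double
print(target)                                                          # [1.99, 0.]
```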
@@ -134,8 +139,7 @@
return tf.group(*ops, name='update_target_network')
def get_gradient_processor(self):
return [MapGradient(lambda grad: \
tf.clip_by_global_norm([grad], 5)[0][0]),
return [MapGradient(lambda grad: tf.clip_by_global_norm([grad], 5)[0][0]),
SummaryGradient()]
def get_config():
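One detail of the gradient processor above: `tf.clip_by_global_norm([grad], 5)[0][0]` runs inside `MapGradient`, i.e. on a single-element list per gradient, so each gradient tensor is clipped by its own L2 norm rather than by the joint norm of all gradients. A NumPy sketch of that per-tensor clipping (values are illustrative):

```python
import numpy as np

def clip_by_norm(grad, clip_norm=5.0):
    # Equivalent effect to tf.clip_by_global_norm([grad], clip_norm)[0][0]
    # for a single tensor: rescale only when its own norm exceeds clip_norm.
    norm = np.linalg.norm(grad)
    return grad if norm <= clip_norm else grad * (clip_norm / norm)

g = np.array([6.0, 8.0])      # norm 10 -> rescaled to norm 5
print(clip_by_norm(g))        # [3. 4.]
```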
@@ -143,7 +147,7 @@ def get_config():
M = Model()
dataset_train = ExpReplay(
predictor_io_names=(['state'], ['fct/output']),
predictor_io_names=(['state'], ['Qvalue']),
player=get_player(train=True),
batch_size=BATCH_SIZE,
memory_size=MEMORY_SIZE,
@@ -167,7 +171,7 @@ def get_config():
[(150, 4e-4), (250, 1e-4), (350, 5e-5)]),
RunOp(lambda: M.update_target_param()),
dataset_train,
PeriodicCallback(Evaluator(EVAL_EPISODE, ['state'], ['fct/output']), 3),
PeriodicCallback(Evaluator(EVAL_EPISODE, ['state'], ['Qvalue']), 3),
#HumanHyperParamSetter('learning_rate', 'hyper.txt'),
#HumanHyperParamSetter(ObjAttrParam(dataset_train, 'exploration'), 'hyper.txt'),
]),
@@ -197,7 +201,7 @@ if __name__ == '__main__':
model=Model(),
session_init=SaverRestore(args.load),
input_var_names=['state'],
output_var_names=['fct/output:0'])
output_var_names=['Qvalue'])
if args.task == 'play':
play_model(cfg)
elif args.task == 'eval':
......
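Renaming the prediction output to `Qvalue` gives ExpReplay, the Evaluator, and the play/eval predictors a single tensor name that works for both the plain and the dueling head (the old name `fct/output` only exists in the non-dueling graph). A sketch of how such a Q-value output is typically turned into an ε-greedy action during play; the `predictor` callable and the shapes here are assumptions, not part of this commit:

```python
import random
import numpy as np

def choose_action(predictor, state, num_actions, epsilon):
    """Pick an action from a Q-value prediction such as the 'Qvalue' tensor above.

    `predictor` is assumed to map a (1, H, W, C) state batch to a
    (1, num_actions) array of Q-values.
    """
    if random.random() < epsilon:
        return random.randrange(num_actions)       # explore uniformly
    q_values = predictor(state[None, ...])[0]      # (num_actions,)
    return int(np.argmax(q_values))                # act greedily w.r.t. Q
```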
@@ -2,7 +2,7 @@
[video demo](https://youtu.be/o21mddZtE5Y)
Reproduce the following reinforcement learning methods:
Reproduce the following reinforcement learning papers:
+ Nature-DQN in:
[Human-level Control Through Deep Reinforcement Learning](http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html)
@@ -10,6 +10,8 @@ Reproduce the following reinforcement learning methods:
+ Double-DQN in:
[Deep Reinforcement Learning with Double Q-learning](http://arxiv.org/abs/1509.06461)
+ Dueling-DQN in: [Dueling Network Architectures for Deep Reinforcement Learning](https://arxiv.org/abs/1511.06581)
+ A3C in [Asynchronous Methods for Deep Reinforcement Learning](http://arxiv.org/abs/1602.01783). (I
used a modified version where each batch contains transitions from different simulators, which I called "Batch-A3C".)
......