Commit d3167ba3 authored by Yuxin Wu

add dueling

parent 4f3d4e27
@@ -23,13 +23,13 @@ import common
 from common import play_model, Evaluator, eval_model_multithread
 from atari import AtariPlayer

+METHOD = ['DQN', 'Double', 'Dueling'][1]
 BATCH_SIZE = 64
 IMAGE_SIZE = (84, 84)
 FRAME_HISTORY = 4
 ACTION_REPEAT = 4
-HEIGHT_RANGE = (None, None)
-#HEIGHT_RANGE = (36, 204)    # for breakout
-#HEIGHT_RANGE = (28, -8)     # for pong
 CHANNEL = FRAME_HISTORY
 IMAGE_SHAPE3 = IMAGE_SIZE + (CHANNEL,)
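
Aside (illustration, not part of the commit): METHOD is a hard-coded list index, so [1] selects 'Double' by default; editing the index switches between plain DQN, Double-DQN and the new dueling head. A hypothetical way to expose the same switch as a command-line flag (the '--algo' flag is invented here, not in the script):

import argparse

# Hypothetical alternative to the hard-coded index.
parser = argparse.ArgumentParser()
parser.add_argument('--algo', choices=['DQN', 'Double', 'Dueling'],
                    default='Double', help='which DQN variant to build')
METHOD = parser.parse_args().algo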
@@ -50,9 +50,8 @@ NUM_ACTIONS = None
 ROM_FILE = None

 def get_player(viz=False, train=False):
-    pl = AtariPlayer(ROM_FILE, height_range=HEIGHT_RANGE,
-            frame_skip=ACTION_REPEAT, image_shape=IMAGE_SIZE[::-1], viz=viz,
-            live_lost_as_eoe=train)
+    pl = AtariPlayer(ROM_FILE, frame_skip=ACTION_REPEAT,
+            image_shape=IMAGE_SIZE[::-1], viz=viz, live_lost_as_eoe=train)
     global NUM_ACTIONS
     NUM_ACTIONS = pl.get_action_space().num_actions()
     if not train:
@@ -76,7 +75,7 @@ class Model(ModelDesc):
         """ image: [0,255]"""
         image = image / 255.0
         with argscope(Conv2D, nl=PReLU.f, use_bias=True):
-            return (LinearWrap(image)
+            l = (LinearWrap(image)
                  .Conv2D('conv0', out_channel=32, kernel_shape=5)
                  .MaxPooling('pool0', 2)
                  .Conv2D('conv1', out_channel=32, kernel_shape=5)
@@ -90,8 +89,14 @@ class Model(ModelDesc):
                  #.Conv2D('conv1', out_channel=64, kernel_shape=4, stride=2)
                  #.Conv2D('conv2', out_channel=64, kernel_shape=3)
-                 .FullyConnected('fc0', 512, nl=lambda x, name: LeakyReLU.f(x, 0.01, name))
-                 .FullyConnected('fct', NUM_ACTIONS, nl=tf.identity)())
+                 .FullyConnected('fc0', 512, nl=lambda x, name: LeakyReLU.f(x, 0.01, name))())
+        if METHOD != 'Dueling':
+            Q = FullyConnected('fct', l, NUM_ACTIONS, nl=tf.identity)
+        else:
+            V = FullyConnected('fctV', l, 1, nl=tf.identity)
+            As = FullyConnected('fctA', l, NUM_ACTIONS, nl=tf.identity)
+            Q = tf.add(As, V - tf.reduce_mean(As, 1, keep_dims=True))
+        return tf.identity(Q, name='Qvalue')

     def _build_graph(self, inputs):
         state, action, reward, next_state, isOver = inputs
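
Aside (illustration, not part of the commit): the new dueling head implements Q(s,a) = V(s) + A(s,a) - mean_a A(s,a) from Wang et al. (arXiv:1511.06581); subtracting the mean advantage keeps V and A identifiable. A minimal numpy sketch of the same aggregation:

import numpy as np

def dueling_q(V, A):
    # V: (N, 1) state values, A: (N, num_actions) advantages.
    # Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)
    return V + (A - A.mean(axis=1, keepdims=True))

V = np.array([[1.0]])
A = np.array([[0.5, -0.5, 1.0]])
print(dueling_q(V, A))    # ≈ [[1.1667  0.1667  1.6667]]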
@@ -105,9 +110,10 @@ class Model(ModelDesc):
         with tf.variable_scope('target'):
             targetQ_predict_value = self._get_DQN_prediction(next_state)    # NxA

+        if METHOD != 'Double':
             # DQN
-            #best_v = tf.reduce_max(targetQ_predict_value, 1)    # N,
+            best_v = tf.reduce_max(targetQ_predict_value, 1)    # N,
+        else:
             # Double-DQN
             tf.get_variable_scope().reuse_variables()
             next_predict_value = self._get_DQN_prediction(next_state)
@@ -117,10 +123,9 @@ class Model(ModelDesc):
         target = reward + (1.0 - tf.cast(isOver, tf.float32)) * GAMMA * tf.stop_gradient(best_v)

-        cost = symbf.huber_loss(target - pred_action_value)
+        self.cost = symbf.huber_loss(target - pred_action_value, name='cost')
         summary.add_param_summary([('conv.*/W', ['histogram', 'rms']),
                                    ('fc.*/W', ['histogram', 'rms']) ])   # monitor all W
-        self.cost = tf.reduce_mean(cost, name='cost')

     def update_target_param(self):
         vars = tf.trainable_variables()
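
Aside (illustration, not part of the commit): the two hunks above switch between the Nature-DQN target, where the target network both selects and evaluates the greedy action, and the Double-DQN target, where the online network selects the action and the target network evaluates it (van Hasselt et al., arXiv:1509.06461). A small numpy sketch of the two targets:

import numpy as np

GAMMA = 0.99

def dqn_target(reward, is_over, target_q_next):
    # Nature-DQN: max over the target network's own Q-values.
    best_v = target_q_next.max(axis=1)
    return reward + (1.0 - is_over) * GAMMA * best_v

def double_dqn_target(reward, is_over, online_q_next, target_q_next):
    # Double-DQN: online network selects, target network evaluates.
    greedy = online_q_next.argmax(axis=1)
    best_v = target_q_next[np.arange(len(greedy)), greedy]
    return reward + (1.0 - is_over) * GAMMA * best_v

r = np.array([1.0]); over = np.array([0.0])
online = np.array([[0.2, 0.9]]); target = np.array([[0.5, 0.3]])
print(dqn_target(r, over, target))                   # 1 + 0.99 * 0.5
print(double_dqn_target(r, over, online, target))    # 1 + 0.99 * 0.3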
@@ -134,8 +139,7 @@ class Model(ModelDesc):
         return tf.group(*ops, name='update_target_network')

     def get_gradient_processor(self):
-        return [MapGradient(lambda grad: \
-                tf.clip_by_global_norm([grad], 5)[0][0]),
+        return [MapGradient(lambda grad: tf.clip_by_global_norm([grad], 5)[0][0]),
                 SummaryGradient()]

 def get_config():
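
Aside (illustration, not part of the commit): tf.clip_by_global_norm([grad], 5)[0][0] inside MapGradient clips each gradient tensor by its own L2 norm, which is equivalent to tf.clip_by_norm(grad, 5) per tensor rather than a joint clip over all gradients. A small TF 1.x graph-mode check:

import tensorflow as tf

g = tf.constant([3.0, 4.0, 12.0])                 # L2 norm = 13
clipped_a = tf.clip_by_global_norm([g], 5)[0][0]  # single-tensor "global" clip
clipped_b = tf.clip_by_norm(g, 5)                 # per-tensor clip
with tf.Session() as sess:
    print(sess.run([clipped_a, clipped_b]))       # both scaled by 5/13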
@@ -143,7 +147,7 @@ def get_config():
    M = Model()

    dataset_train = ExpReplay(
-       predictor_io_names=(['state'], ['fct/output']),
+       predictor_io_names=(['state'], ['Qvalue']),
        player=get_player(train=True),
        batch_size=BATCH_SIZE,
        memory_size=MEMORY_SIZE,
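
Aside (illustration, not part of the commit): because the head now ends in tf.identity(Q, name='Qvalue'), the experience-replay predictor and the Evaluator can fetch the Q-values by the tensor name 'Qvalue' regardless of which head (plain or dueling) produced them, instead of the layer-specific 'fct/output'. A tiny TF 1.x toy graph showing fetch-by-name (not the model above):

import tensorflow as tf

state = tf.placeholder(tf.float32, [None, 4], name='state')
q = tf.identity(state * 2.0, name='Qvalue')       # stand-in for the real head
with tf.Session() as sess:
    print(sess.run('Qvalue:0', feed_dict={'state:0': [[1., 2., 3., 4.]]}))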
@@ -167,7 +171,7 @@ def get_config():
                              [(150, 4e-4), (250, 1e-4), (350, 5e-5)]),
            RunOp(lambda: M.update_target_param()),
            dataset_train,
-           PeriodicCallback(Evaluator(EVAL_EPISODE, ['state'], ['fct/output']), 3),
+           PeriodicCallback(Evaluator(EVAL_EPISODE, ['state'], ['Qvalue']), 3),
            #HumanHyperParamSetter('learning_rate', 'hyper.txt'),
            #HumanHyperParamSetter(ObjAttrParam(dataset_train, 'exploration'), 'hyper.txt'),
        ]),
@@ -197,7 +201,7 @@ if __name__ == '__main__':
            model=Model(),
            session_init=SaverRestore(args.load),
            input_var_names=['state'],
-           output_var_names=['fct/output:0'])
+           output_var_names=['Qvalue'])
        if args.task == 'play':
            play_model(cfg)
        elif args.task == 'eval':
...
@@ -2,7 +2,7 @@
 [video demo](https://youtu.be/o21mddZtE5Y)

-Reproduce the following reinforcement learning methods:
+Reproduce the following reinforcement learning papers:

 + Nature-DQN in:
 [Human-level Control Through Deep Reinforcement Learning](http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html)
@@ -10,6 +10,8 @@ Reproduce the following reinforcement learning methods:
 + Double-DQN in:
 [Deep Reinforcement Learning with Double Q-learning](http://arxiv.org/abs/1509.06461)

++ Dueling-DQN in:
+[Dueling Network Architectures for Deep Reinforcement Learning](https://arxiv.org/abs/1511.06581)

 + A3C in [Asynchronous Methods for Deep Reinforcement Learning](http://arxiv.org/abs/1602.01783). (I
 used a modified version where each batch contains transitions from different simulators, which I called "Batch-A3C".)
...