Commit 1f3eaf97 authored by Yuxin Wu

bug fix in Double-DQN

parent d6f1b6ee
@@ -92,24 +92,27 @@ class Model(ModelDesc):
     def _build_graph(self, inputs, is_training):
         state, action, reward, next_state, isOver = inputs
         self.predict_value = self._get_DQN_prediction(state, is_training)
-        action_onehot = tf.one_hot(action, NUM_ACTIONS)
+        action_onehot = tf.one_hot(action, NUM_ACTIONS, 1.0, 0.0)
         pred_action_value = tf.reduce_sum(self.predict_value * action_onehot, 1)  #N,
         max_pred_reward = tf.reduce_mean(tf.reduce_max(
             self.predict_value, 1), name='predict_reward')
         add_moving_summary(max_pred_reward)
-        self.greedy_choice = tf.argmax(self.predict_value, 1)  # N,

         with tf.variable_scope('target'):
             targetQ_predict_value = self._get_DQN_prediction(next_state, False)  # NxA
             # DQN
             #best_v = tf.reduce_max(targetQ_predict_value, 1)  # N,
             # Double-DQN
-            predict_onehot = tf.one_hot(self.greedy_choice, NUM_ACTIONS, 1.0, 0.0)
-            best_v = tf.reduce_sum(targetQ_predict_value * predict_onehot, 1)
+            tf.get_variable_scope().reuse_variables()
+            next_predict_value = self._get_DQN_prediction(next_state, is_training)
+            self.greedy_choice = tf.argmax(next_predict_value, 1)  # N,
+            predict_onehot = tf.one_hot(self.greedy_choice, NUM_ACTIONS, 1.0, 0.0)
+            best_v = tf.reduce_sum(targetQ_predict_value * predict_onehot, 1)

         target = reward + (1.0 - tf.cast(isOver, tf.float32)) * GAMMA * tf.stop_gradient(best_v)

         sqrcost = tf.square(target - pred_action_value)
         abscost = tf.abs(target - pred_action_value)  # robust error func
...
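For readers skimming the commit: before this change, `self.greedy_choice` was the argmax over `self.predict_value`, i.e. the Q-values of the *current* state, and that action was then used to index the target network's Q-values for `next_state`. The fix recomputes the greedy action from a prediction on `next_state` (reusing variables) and only then evaluates that action with the target network, i.e. the Double-DQN target `target = reward + (1 - isOver) * GAMMA * Q_target(next_state, argmax_a Q(next_state, a))`. Below is a minimal NumPy sketch of that target computation, independent of the graph code above; all names (`double_dqn_target`, `q_next`, `target_q_next`) are illustrative and not from the repository.

```python
import numpy as np

GAMMA = 0.99  # same role as the GAMMA hyperparameter in the diff above

def double_dqn_target(reward, is_over, q_next, target_q_next, gamma=GAMMA):
    """Illustrative Double-DQN target (names are made up for this sketch).

    reward, is_over: shape (N,)
    q_next:          Q(next_state, a) used for action *selection*, shape (N, A)
    target_q_next:   Q(next_state, a) from the target net, used for *evaluation*, shape (N, A)
    """
    greedy_choice = np.argmax(q_next, axis=1)                             # pick the action on next_state
    best_v = target_q_next[np.arange(len(greedy_choice)), greedy_choice]  # evaluate it with the target net
    return reward + (1.0 - is_over.astype(np.float32)) * gamma * best_v   # no bootstrap on terminal states

# tiny usage example with random data
rng = np.random.default_rng(0)
reward = rng.random(4).astype(np.float32)
is_over = np.array([0, 0, 1, 0])
q_next = rng.random((4, 6)).astype(np.float32)
target_q_next = rng.random((4, 6)).astype(np.float32)
print(double_dqn_target(reward, is_over, q_next, target_q_next))
```

The commented-out `best_v = tf.reduce_max(targetQ_predict_value, 1)` line in the diff is the vanilla-DQN alternative, which both selects and evaluates the action with the target network.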
@@ -17,6 +17,7 @@ __all__ = ['total_timer', 'timed_operation',
     'print_total_timer', 'IterSpeedCounter']


 class IterSpeedCounter(object):
+    """ To count how often some code gets reached"""
     def __init__(self, print_every, name=None):
         self.cnt = 0
         self.print_every = int(print_every)
...
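The second hunk only adds a docstring to `IterSpeedCounter`; apart from the constructor signature `(print_every, name=None)`, none of the class's interface is visible here. As a rough illustration of the pattern the docstring describes (counting how often some code gets reached and reporting the rate), here is a self-contained stand-in, not the repository's implementation:

```python
import time

class SimpleIterSpeedCounter(object):
    """Illustrative stand-in (not the repo's class): count how often some code
    gets reached and report the rate every `print_every` hits."""
    def __init__(self, print_every, name=None):
        self.cnt = 0
        self.print_every = int(print_every)
        self.name = name or 'IterSpeed'
        self.start = time.time()

    def __call__(self):
        self.cnt += 1
        if self.cnt % self.print_every == 0:
            elapsed = time.time() - self.start
            print("{}: reached {} times, {:.1f} it/s".format(
                self.name, self.cnt, self.cnt / max(elapsed, 1e-9)))

# usage: call the counter at the spot you want to instrument
speed = SimpleIterSpeedCounter(print_every=1000, name='train-loop')
for _ in range(5000):
    speed()
```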