Commit 27d73303 authored by Yuxin Wu

fix 'logits' naming in A3C (fix #197)

parent 7e2be137
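Note on the rename (not part of the original commit message): the raw output of the 'fc-pi' layer is a vector of unnormalized scores ("logits"), and only its softmax is a probability distribution over actions (the "policy"). A minimal sketch of that distinction in plain NumPy, separate from the graph code touched below; the names here are illustrative only:

    import numpy as np

    def softmax(x):
        e = np.exp(x - x.max())          # subtract max for numerical stability
        return e / e.sum()

    logits = np.array([2.0, -1.0, 0.5])  # unnormalized scores, e.g. from an 'fc-pi' layer
    policy = softmax(logits)             # normalized action probabilities
    print(policy.sum())                  # 1.0 -- 'policy' is a distribution, 'logits' is not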
@@ -64,7 +64,7 @@ class Model(ModelDesc):
     def _build_graph(self, inputs):
         state, action, futurereward = inputs
         policy = self._get_NN_prediction(state)
-        self.logits = tf.nn.softmax(policy, name='logits')
+        policy = tf.nn.softmax(policy, name='policy')


 def run_submission(cfg, output, nr):
@@ -105,5 +105,5 @@ if __name__ == '__main__':
         model=Model(),
         session_init=SaverRestore(args.load),
         input_names=['state'],
-        output_names=['logits'])
+        output_names=['policy'])
     run_submission(cfg, args.output, args.episode)
@@ -96,30 +96,30 @@ class Model(ModelDesc):
         l = FullyConnected('fc0', l, 512, nl=tf.identity)
         l = PReLU('prelu', l)
-        policy = FullyConnected('fc-pi', l, out_dim=NUM_ACTIONS, nl=tf.identity)
+        logits = FullyConnected('fc-pi', l, out_dim=NUM_ACTIONS, nl=tf.identity)    # unnormalized policy
         value = FullyConnected('fc-v', l, 1, nl=tf.identity)
-        return policy, value
+        return logits, value

     def _build_graph(self, inputs):
         state, action, futurereward = inputs
-        policy, self.value = self._get_NN_prediction(state)
+        logits, self.value = self._get_NN_prediction(state)
         self.value = tf.squeeze(self.value, [1], name='pred_value')  # (B,)
-        self.logits = tf.nn.softmax(policy, name='logits')
+        self.policy = tf.nn.softmax(logits, name='policy')
         expf = tf.get_variable('explore_factor', shape=[],
                                initializer=tf.constant_initializer(1), trainable=False)
-        logitsT = tf.nn.softmax(policy * expf, name='logitsT')
+        policy_explore = tf.nn.softmax(logits * expf, name='policy_explore')
         is_training = get_current_tower_context().is_training
         if not is_training:
             return
-        log_probs = tf.log(self.logits + 1e-6)
+        log_probs = tf.log(self.policy + 1e-6)
         log_pi_a_given_s = tf.reduce_sum(
             log_probs * tf.one_hot(action, NUM_ACTIONS), 1)
         advantage = tf.subtract(tf.stop_gradient(self.value), futurereward, name='advantage')
         policy_loss = tf.reduce_sum(log_pi_a_given_s * advantage, name='policy_loss')
         xentropy_loss = tf.reduce_sum(
-            self.logits * log_probs, name='xentropy_loss')
+            self.policy * log_probs, name='xentropy_loss')
         value_loss = tf.nn.l2_loss(self.value - futurereward, name='value_loss')

         pred_reward = tf.reduce_mean(self.value, name='predict_reward')
@@ -151,7 +151,7 @@ class MySimulatorMaster(SimulatorMaster, Callback):
     def _setup_graph(self):
         self.async_predictor = MultiThreadAsyncPredictor(
-            self.trainer.get_predictors(['state'], ['logitsT', 'pred_value'],
+            self.trainer.get_predictors(['state'], ['policy_explore', 'pred_value'],
                                         PREDICTOR_THREAD), batch_size=15)

     def _before_train(self):
@@ -220,7 +220,7 @@ def get_config():
                 [(80, 2), (100, 3), (120, 4), (140, 5)]),
             master,
             StartProcOrThread(master),
-            PeriodicCallback(Evaluator(EVAL_EPISODE, ['state'], ['logits']), 2),
+            PeriodicCallback(Evaluator(EVAL_EPISODE, ['state'], ['policy']), 2),
         ],
         session_creator=sesscreate.NewSessionCreator(
             config=get_default_sess_config(0.5)),
@@ -254,7 +254,7 @@ if __name__ == '__main__':
         model=Model(),
         session_init=SaverRestore(args.load),
         input_names=['state'],
-        output_names=['logits'])
+        output_names=['policy'])
     if args.task == 'play':
         play_model(cfg)
     elif args.task == 'eval':
@@ -43,15 +43,15 @@ class GANModelDesc(ModelDesc):
             d_pos_acc = tf.reduce_mean(tf.cast(score_real > 0.5, tf.float32), name='accuracy_real')
             d_neg_acc = tf.reduce_mean(tf.cast(score_fake < 0.5, tf.float32), name='accuracy_fake')
-            self.d_accuracy = tf.add(.5 * d_pos_acc, .5 * d_neg_acc, name='accuracy')
+            d_accuracy = tf.add(.5 * d_pos_acc, .5 * d_neg_acc, name='accuracy')

             self.d_loss = tf.add(.5 * d_loss_pos, .5 * d_loss_neg, name='loss')

         with tf.name_scope("gen"):
             self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                 logits=logits_fake, labels=tf.ones_like(logits_fake)), name='loss')
-            self.g_accuracy = tf.reduce_mean(tf.cast(score_fake > 0.5, tf.float32), name='accuracy')
+            g_accuracy = tf.reduce_mean(tf.cast(score_fake > 0.5, tf.float32), name='accuracy')

-        add_moving_summary(self.g_loss, self.d_loss, self.d_accuracy, self.g_accuracy)
+        add_moving_summary(self.g_loss, self.d_loss, d_accuracy, g_accuracy)
class GANTrainer(FeedfreeTrainerBase):