Commit 4587944d authored by Yuxin Wu

small fix

parent fbf93d44
@@ -15,7 +15,6 @@ import six
 from six.moves import queue
 
 from tensorpack import *
-from tensorpack.utils import *
 from tensorpack.utils.concurrency import *
 from tensorpack.utils.serialize import *
 from tensorpack.utils.timer import *
@@ -68,10 +67,10 @@ class Model(ModelDesc):
     def _get_input_vars(self):
         assert NUM_ACTIONS is not None
         return [InputVar(tf.float32, (None,) + IMAGE_SHAPE3, 'state'),
-                InputVar(tf.int32, (None,), 'action'),
+                InputVar(tf.int64, (None,), 'action'),
                 InputVar(tf.float32, (None,), 'futurereward') ]
 
-    def _get_NN_prediction(self, image, is_training):
+    def _get_NN_prediction(self, image):
         image = image / 255.0
         with argscope(Conv2D, nl=tf.nn.relu):
             l = Conv2D('conv0', image, out_channel=32, kernel_shape=5)
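The int64 dtype here pairs with the tf.one_hot change in the next hunk: once the 'action' input is declared int64, it can be fed to tf.one_hot without the explicit cast the old code carried. A minimal sketch (TF 0.x-era API, as used elsewhere in this file; the NUM_ACTIONS value is illustrative):

    import tensorflow as tf

    NUM_ACTIONS = 6  # illustrative; the real value comes from the Atari env
    action = tf.placeholder(tf.int64, (None,), name='action')
    # on_value/off_value default to 1.0/0.0, so they need not be spelled out
    onehot = tf.one_hot(action, NUM_ACTIONS)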
@@ -88,21 +87,22 @@ class Model(ModelDesc):
         value = FullyConnected('fc-v', l, 1, nl=tf.identity)
         return policy, value
 
-    def _build_graph(self, inputs, is_training):
+    def _build_graph(self, inputs):
         state, action, futurereward = inputs
-        policy, self.value = self._get_NN_prediction(state, is_training)
+        policy, self.value = self._get_NN_prediction(state)
         self.value = tf.squeeze(self.value, [1], name='pred_value')  # (B,)
         self.logits = tf.nn.softmax(policy, name='logits')
 
         expf = tf.get_variable('explore_factor', shape=[],
                                initializer=tf.constant_initializer(1), trainable=False)
         logitsT = tf.nn.softmax(policy * expf, name='logitsT')
+        is_training = get_current_tower_context().is_training
         if not is_training:
             return
 
         log_probs = tf.log(self.logits + 1e-6)
         log_pi_a_given_s = tf.reduce_sum(
-            log_probs * tf.one_hot(tf.cast(action, tf.int64), NUM_ACTIONS, 1.0, 0.0), 1)
+            log_probs * tf.one_hot(action, NUM_ACTIONS), 1)
         advantage = tf.sub(tf.stop_gradient(self.value), futurereward, name='advantage')
         policy_loss = tf.reduce_sum(log_pi_a_given_s * advantage, name='policy_loss')
         xentropy_loss = tf.reduce_sum(
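The is_training parameter disappears from the signatures above because tensorpack now publishes the flag on the ambient tower context, so graph-building code queries it instead of threading it through every call. A sketch of the pattern, assuming get_current_tower_context arrives via the wildcard tensorpack import:

    def _build_graph(self, inputs):
        # ... build the ops needed by both training and inference ...
        # ask the ambient tower context instead of taking a flag argument
        is_training = get_current_tower_context().is_training
        if not is_training:
            return  # inference tower: predictions only, no loss ops
        # ... build losses and summaries for the training tower ...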
@@ -110,7 +110,7 @@ class Model(ModelDesc):
         value_loss = tf.nn.l2_loss(self.value - futurereward, name='value_loss')
         pred_reward = tf.reduce_mean(self.value, name='predict_reward')
-        advantage = tf.sqrt(tf.reduce_mean(tf.square(advantage)), name='rms_advantage')
+        advantage = symbf.rms(advantage, name='rms_advantage')
         summary.add_moving_summary(policy_loss, xentropy_loss, value_loss, pred_reward, advantage)
 
         entropy_beta = tf.get_variable('entropy_beta', shape=[],
                                        initializer=tf.constant_initializer(0.01), trainable=False)
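The removed line spells out exactly what symbf.rms computes: the root-mean-square of a tensor. A minimal equivalent, assuming symbf aliases tensorpack's symbolic-function helpers:

    import tensorflow as tf

    def rms(x, name='rms'):
        # sqrt(mean(x^2)) -- the same ops the old inline expression used
        return tf.sqrt(tf.reduce_mean(tf.square(x)), name=name)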
@@ -139,7 +139,7 @@ class MySimulatorMaster(SimulatorMaster, Callback):
     def _on_state(self, state, ident):
         def cb(outputs):
             distrib, value = outputs.result()
-            assert np.all(np.isfinite(distrib))
+            assert np.all(np.isfinite(distrib)), distrib
             action = np.random.choice(len(distrib), p=distrib)
             client = self.clients[ident]
             client.memory.append(TransitionExperience(state, action, None, value=value))
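Appending distrib as the assert message means a failed finiteness check prints the offending policy distribution rather than a bare AssertionError:

    import numpy as np

    distrib = np.array([0.5, np.nan, 0.5])  # illustrative bad network output
    assert np.all(np.isfinite(distrib)), distrib
    # AssertionError: [0.5 nan 0.5]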