Commit d493c30b authored by Yuxin Wu

Add importance sampling to A3C

parent fe323b6b
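Background for the diff below: with asynchronous simulator processes, the collected experience is mildly off-policy, because actions are sampled from a copy of the policy that can lag behind the parameters being trained. The commit records the behaviour probability of each chosen action (`prob=distrib[action]`), passes it to the trainer as a new `action_prob` input, and multiplies the policy-gradient term by a clipped importance ratio between the current policy and that recorded probability; the earlier `explore_factor`/`policy_explore` exploration trick is removed, so actions are sampled directly from `policy`. A minimal NumPy sketch of the reweighting (not part of the commit; array values are made up):

```python
import numpy as np

def clipped_importance(curr_prob, behavior_prob, clip_max=10.0, eps=1e-8):
    """Importance weight pi_current(a|s) / pi_behavior(a|s), clipped as in the diff below."""
    return np.clip(curr_prob / (behavior_prob + eps), 0.0, clip_max)

# Probability of the *taken* action under the current (training-time) policy,
# and under the slightly stale policy that actually produced the action.
curr_prob = np.array([0.50, 0.05, 0.30])
behavior_prob = np.array([0.25, 0.40, 0.30])
advantage = np.array([1.0, -0.5, 0.2])             # stop-gradient'ed value-minus-return in the real code

w = clipped_importance(curr_prob, behavior_prob)   # -> [2.0, 0.125, 1.0]
policy_loss = np.sum(np.log(curr_prob) * advantage * w)
print(w, policy_loss)
```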
@@ -32,10 +32,10 @@ multiprocess Python program to get a cgroup dedicated for the task.
 Download models from [model zoo](https://goo.gl/9yIol2).
 Watch the agent play:
-`./train-atari.py --task play --env Breakout-v0 --load Breakout-v0.tfmodel`
+`./train-atari.py --task play --env Breakout-v0 --load Breakout-v0.npy`
 Generate gym submissions:
-`./train-atari.py --task gen_submit --load Breakout-v0.tfmodel --env Breakout-v0 --output output_dir`
+`./train-atari.py --task gen_submit --load Breakout-v0.npy --env Breakout-v0 --output output_dir`
 Models are available for the following atari environments (click to watch videos of my agent):
@@ -43,23 +43,23 @@ Models are available for the following atari environments (click to watch videos
 | - | - | - | - |
 | [AirRaid](https://gym.openai.com/evaluations/eval_zIeNk5MxSGOmvGEUxrZDUw) | [Alien](https://gym.openai.com/evaluations/eval_8NR1IvjTQkSIT6En4xSMA) | [Amidar](https://gym.openai.com/evaluations/eval_HwEazbHtTYGpCialv9uPhA) | [Assault](https://gym.openai.com/evaluations/eval_tCiHwy5QrSdFVucSbBV6Q) |
 | [Asterix](https://gym.openai.com/evaluations/eval_mees2c58QfKm5GspCjRfCA) | [Asteroids](https://gym.openai.com/evaluations/eval_8eHKsRL4RzuZEq9AOLZA) | [Atlantis](https://gym.openai.com/evaluations/eval_Z1B3d7A1QCaQk1HpO1Rg) | [BankHeist](https://gym.openai.com/evaluations/eval_hifoaxFTIuLlPd38BjnOw) |
-| [BattleZone](https://gym.openai.com/evaluations/eval_SoLit2bR1qmFoC0AsJF6Q) | [BeamRider](https://gym.openai.com/evaluations/eval_KuOYumrjQjixwL0spG0iCA) | [Berzerk](https://gym.openai.com/evaluations/eval_Yri0XQbwRy62NzWILdn5IA) | [Breakout](https://gym.openai.com/evaluations/eval_L55gczPrQJamMGihq9tzA) |
+| [BattleZone](https://gym.openai.com/evaluations/eval_SoLit2bR1qmFoC0AsJF6Q) | [BeamRider](https://gym.openai.com/evaluations/eval_KuOYumrjQjixwL0spG0iCA) | [Berzerk](https://gym.openai.com/evaluations/eval_Yri0XQbwRy62NzWILdn5IA) | [Breakout](https://gym.openai.com/evaluations/eval_NiKaIN4NSUeEIvWqIgVDrA) |
 | [Carnival](https://gym.openai.com/evaluations/eval_xJSOlo2lSWaH1wHEOX5vw) | [Centipede](https://gym.openai.com/evaluations/eval_mc1Kp5e6R42rFdjeMLzkIg) | [ChopperCommand](https://gym.openai.com/evaluations/eval_tYVKyh7wQieRIKgEvVaCuw) | [CrazyClimber](https://gym.openai.com/evaluations/eval_bKeBg0QwSgOm6A0I0wDhSw) |
 | [DemonAttack](https://gym.openai.com/evaluations/eval_tt21vVaRCKYzWFcg1Kw) | [DoubleDunk](https://gym.openai.com/evaluations/eval_FI1GpF4TlCuf29KccTpQ) | [ElevatorAction](https://gym.openai.com/evaluations/eval_SqeAouMvR0icRivx2xprZg) | [FishingDerby](https://gym.openai.com/evaluations/eval_pPLCnFXsTVaayrIboDOs0g) |
 | [Frostbite](https://gym.openai.com/evaluations/eval_qtC3taKFSgWwkO9q9IM4hA) | [Gopher](https://gym.openai.com/evaluations/eval_KVcpR1YgQkEzrL2VIcAQ) | [Gravitar](https://gym.openai.com/evaluations/eval_QudrLdVmTpK9HF5juaZr0w) | [IceHockey](https://gym.openai.com/evaluations/eval_8oWCTwwGS7OUTTGRwBPQkQ) |
 | [Jamesbond](https://gym.openai.com/evaluations/eval_mLF7XPi8Tw66pnjP73JsmA) | [JourneyEscape](https://gym.openai.com/evaluations/eval_S9nQuXLRSu7S5x21Ay6AA) | [Kangaroo](https://gym.openai.com/evaluations/eval_TNJiLB8fTqOPfvINnPXoQ) | [Krull](https://gym.openai.com/evaluations/eval_dfOS2WzhTh6sn1FuPS9HA) |
 | [KungFuMaster](https://gym.openai.com/evaluations/eval_vNWDShYTRC0MhfIybeUYg) | [MsPacman](https://gym.openai.com/evaluations/eval_kpL9bSsS4GXsYb9HuEfew) | [NameThisGame](https://gym.openai.com/evaluations/eval_LZqfv706SdOMtR4ZZIwIsg) | [Phoenix](https://gym.openai.com/evaluations/eval_uzUruiB3RRKUMvJIxvEzYA) |
 | [Pong](https://gym.openai.com/evaluations/eval_8L7SV59nSW6GGbbP3N4G6w) | [Pooyan](https://gym.openai.com/evaluations/eval_UXFVI34MSAuNTtjZcK8N0A) | [Qbert](https://gym.openai.com/evaluations/eval_wekCJkrWQm9NrOUzltXg) | [Riverraid](https://gym.openai.com/evaluations/eval_OU4x3DkTfm4uaXy6CIaXg) |
-| [RoadRunner](https://gym.openai.com/evaluations/eval_wINKQTwxT9ipydHOXBhg) | [Robotank](https://gym.openai.com/evaluations/eval_Gr5c0ld3QACLDPQrGdzbiw) | [Seaquest](https://gym.openai.com/evaluations/eval_N2624y3NSJWrOgoMSpOi4w) | [SpaceInvaders](https://gym.openai.com/evaluations/eval_Eduozx4HRyqgTCVk9ltw) |
+| [RoadRunner](https://gym.openai.com/evaluations/eval_wINKQTwxT9ipydHOXBhg) | [Robotank](https://gym.openai.com/evaluations/eval_Gr5c0ld3QACLDPQrGdzbiw) | [Seaquest](https://gym.openai.com/evaluations/eval_pjjgc9POQJK4IuVw8nXlBw) | [SpaceInvaders](https://gym.openai.com/evaluations/eval_Eduozx4HRyqgTCVk9ltw) |
 | [StarGunner](https://gym.openai.com/evaluations/eval_JB5cOJXFSS2cTQ7dXK8Iag) | [Tennis](https://gym.openai.com/evaluations/eval_gDjJD0MMS1yLm1T0hdqI4g) | [Tutankham](https://gym.openai.com/evaluations/eval_gDjJD0MMS1yLm1T0hdqI4g) | [UpNDown](https://gym.openai.com/evaluations/eval_KmkvMJkxQFSED20wFUMdIA) |
 | [VideoPinball](https://gym.openai.com/evaluations/eval_PWwzNhVFR2CxjYvEsPfT1g) | [WizardOfWor](https://gym.openai.com/evaluations/eval_1oGQhphpQhmzEMIYRrrp0A) | [Zaxxon](https://gym.openai.com/evaluations/eval_TIQ102EwTrHrOyve2RGfg) | |
-Note that atari game settings in gym are quite different from DeepMind papers, so the scores are not comparable. The most notable differences are:
+Note that atari game settings in gym (AtariGames-v0) are quite different from DeepMind papers, so the scores are not comparable. The most notable differences are:
-+ In gym, each action is randomly repeated 2~4 times.
++ Each action is randomly repeated 2~4 times.
-+ In gym, inputs are RGB instead of greyscale.
++ Inputs are RGB instead of greyscale.
-+ In gym, an episode is limited to 10000 steps.
++ An episode is limited to 10000 steps.
-+ The action space also seems to be different.
++ Loss of a life is not the end of an episode.
 Also see the DQN implementation [here](../DeepQNetwork)
@@ -64,7 +64,7 @@ def get_player(viz=False, train=False, dumpdir=None):
     if not train:
         pl = PreventStuckPlayer(pl, 30, 1)
     else:
-        pl = LimitLengthPlayer(pl, 40000)
+        pl = LimitLengthPlayer(pl, 60000)
     return pl
@@ -78,7 +78,9 @@ class Model(ModelDesc):
         assert NUM_ACTIONS is not None
         return [InputDesc(tf.uint8, (None,) + IMAGE_SHAPE3, 'state'),
                 InputDesc(tf.int64, (None,), 'action'),
-                InputDesc(tf.float32, (None,), 'futurereward')]
+                InputDesc(tf.float32, (None,), 'futurereward'),
+                InputDesc(tf.float32, (None,), 'action_prob'),
+                ]

     def _get_NN_prediction(self, image):
         image = tf.cast(image, tf.float32) / 255.0
@@ -98,14 +100,10 @@ class Model(ModelDesc):
         return logits, value

     def _build_graph(self, inputs):
-        state, action, futurereward = inputs
+        state, action, futurereward, action_prob = inputs
         logits, self.value = self._get_NN_prediction(state)
         self.value = tf.squeeze(self.value, [1], name='pred_value')  # (B,)
         self.policy = tf.nn.softmax(logits, name='policy')
-        expf = tf.get_variable('explore_factor', shape=[],
-                               initializer=tf.constant_initializer(1), trainable=False)
-        policy_explore = tf.nn.softmax(logits * expf, name='policy_explore')
         is_training = get_current_tower_context().is_training
         if not is_training:
             return
@@ -114,7 +112,11 @@ class Model(ModelDesc):
         log_pi_a_given_s = tf.reduce_sum(
             log_probs * tf.one_hot(action, NUM_ACTIONS), 1)
         advantage = tf.subtract(tf.stop_gradient(self.value), futurereward, name='advantage')
-        policy_loss = tf.reduce_sum(log_pi_a_given_s * advantage, name='policy_loss')
+        pi_a_given_s = tf.reduce_sum(self.policy * tf.one_hot(action, NUM_ACTIONS), 1)  # (B,)
+        importance = tf.stop_gradient(tf.clip_by_value(pi_a_given_s / (action_prob + 1e-8), 0, 10))
+        policy_loss = tf.reduce_sum(log_pi_a_given_s * advantage * importance, name='policy_loss')
         xentropy_loss = tf.reduce_sum(
             self.policy * log_probs, name='xentropy_loss')
         value_loss = tf.nn.l2_loss(self.value - futurereward, name='value_loss')
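A note on the two guards in the `importance` tensor (my reading, not stated in the commit): `tf.stop_gradient` turns the ratio into a constant per-sample weight rather than something gradients flow through, and clipping to `[0, 10]` keeps a single transition with a tiny recorded `action_prob` from dominating the batch gradient. A small illustrative sketch with made-up numbers:

```python
import numpy as np

eps = 1e-8
curr = np.array([0.40, 0.40])        # pi(a|s) under the current parameters
behavior = np.array([0.40, 1e-4])    # recorded prob; the second action was a rare sample
raw = curr / (behavior + eps)        # -> [1.0, ~4000.0]
clipped = np.clip(raw, 0, 10)        # -> [1.0, 10.0]
print(raw, clipped)
```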
@@ -128,7 +130,8 @@ class Model(ModelDesc):
                                tf.cast(tf.shape(futurereward)[0], tf.float32),
                                name='cost')
         summary.add_moving_summary(policy_loss, xentropy_loss,
-                                   value_loss, pred_reward, advantage, self.cost)
+                                   value_loss, pred_reward, advantage,
+                                   self.cost, tf.reduce_mean(importance, name='importance'))

     def _get_optimizer(self):
         lr = symbf.get_scalar_var('learning_rate', 0.001, summary=True)
@@ -148,7 +151,7 @@ class MySimulatorMaster(SimulatorMaster, Callback):
     def _setup_graph(self):
         self.async_predictor = MultiThreadAsyncPredictor(
-            self.trainer.get_predictors(['state'], ['policy_explore', 'pred_value'],
+            self.trainer.get_predictors(['state'], ['policy', 'pred_value'],
                                         PREDICTOR_THREAD), batch_size=PREDICT_BATCH_SIZE)

     def _before_train(self):
@@ -164,7 +167,8 @@ class MySimulatorMaster(SimulatorMaster, Callback):
             assert np.all(np.isfinite(distrib)), distrib
             action = np.random.choice(len(distrib), p=distrib)
             client = self.clients[ident]
-            client.memory.append(TransitionExperience(state, action, None, value=value))
+            client.memory.append(TransitionExperience(
+                state, action, reward=None, value=value, prob=distrib[action]))
             self.send_queue.put([ident, dumps(action)])
         self.async_predictor.put_task([state], cb)
@@ -188,7 +192,7 @@ class MySimulatorMaster(SimulatorMaster, Callback):
         R = float(init_r)
         for idx, k in enumerate(mem):
             R = np.clip(k.reward, -1, 1) + GAMMA * R
-            self.queue.put([k.state, k.action, R])
+            self.queue.put([k.state, k.action, R, k.prob])

         if not isOver:
             client.memory = [last]
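For context, the tuples pushed onto `self.queue` above are what the trainer batches into the `state`/`action`/`futurereward`/`action_prob` inputs declared earlier. A rough, self-contained sketch of how an episode fragment becomes those tuples (hypothetical `Transition` tuple, `GAMMA`, and `put` callback; the real logic lives in `MySimulatorMaster`):

```python
from collections import namedtuple
import numpy as np

Transition = namedtuple('Transition', ['state', 'action', 'reward', 'prob'])
GAMMA = 0.99

def parse_memory(mem, init_r, put):
    """Accumulate the clipped discounted return from the most recent step backwards."""
    R = float(init_r)                        # bootstrap with V(last state), or 0 if the episode ended
    for k in reversed(mem):                  # newest transition first
        R = np.clip(k.reward, -1, 1) + GAMMA * R
        put([k.state, k.action, R, k.prob])  # prob later becomes the importance-weight denominator

mem = [Transition(state=i, action=0, reward=1.0, prob=0.5) for i in range(3)]
parse_memory(mem, init_r=0.0, put=print)
```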
@@ -216,8 +220,6 @@ def get_config():
             ModelSaver(),
             ScheduledHyperParamSetter('learning_rate', [(20, 0.0003), (120, 0.0001)]),
             ScheduledHyperParamSetter('entropy_beta', [(80, 0.005)]),
-            ScheduledHyperParamSetter('explore_factor',
-                                      [(80, 2), (100, 3), (120, 4), (140, 5)]),
             HumanHyperParamSetter('learning_rate'),
             HumanHyperParamSetter('entropy_beta'),
             master,