Commit 55098813 authored by Yuxin Wu's avatar Yuxin Wu

Upgrade gym

parent 4bc0c748
......@@ -362,7 +362,7 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
# include_init_with_doc doesn't work well for decorated init
# https://github.com/sphinx-doc/sphinx/issues/4258
return False
# Hide some names that are deprecated or not intended to be used
if name in [
# deprecated stuff:
'GaussianDeform',
......
### A3C code and models for Atari games in gym
Multi-GPU version of the A3C algorithm in
[Asynchronous Methods for Deep Reinforcement Learning](http://arxiv.org/abs/1602.01783).
Results of the same code trained on 47 different Atari games were uploaded to OpenAI Gym.
Most of them were the best reproducible results on gym.
However, OpenAI has since removed the leaderboard from their site entirely.
### To train on an Atari game:
......@@ -17,10 +16,10 @@ The speed is about 6~10 iterations/s on 1 GPU plus 12+ CPU cores.
With 2 TitanX GPUs and 20+ CPU cores, setting `SIMULATOR_PROC=240, PREDICT_BATCH_SIZE=30, PREDICTOR_THREAD_PER_GPU=6` improves this to 16 it/s (2K images/s).
Note that the network architecture is larger than what's used in the original paper.
The pretrained models are all trained with 4 GPUs for about 2 days.
But on simple games like Breakout, you can get good performance within several hours.
Also note that multi-GPU doesn't give you obvious speedup here,
because the bottleneck in this implementation is not computation but simulation.
Some practical notes:
......@@ -36,28 +35,32 @@ Download models from [model zoo](http://models.tensorpack.com/OpenAIGym/).
Watch the agent play:
`./train-atari.py --task play --env Breakout-v0 --load Breakout-v0.npz`
Dump some videos:
`./train-atari.py --task dump_video --load Breakout-v0.npz --env Breakout-v0 --output output_dir --episode 3`
This table lists the available pretrained models and their scores (average over 100 episodes),
with links to the corresponding submissions.
The site is no longer maintained, so the links may become invalid at any time.
| | | | |
| - | - | - | - |
| [AirRaid](https://gym.openai.com/evaluations/eval_zIeNk5MxSGOmvGEUxrZDUw)(2727) | [Alien](https://gym.openai.com/evaluations/eval_8NR1IvjTQkSIT6En4xSMA) (2611) | [Amidar](https://gym.openai.com/evaluations/eval_HwEazbHtTYGpCialv9uPhA)(1376) | [Assault](https://gym.openai.com/evaluations/eval_tCiHwy5QrSdFVucSbBV6Q)(3397) |
| [Asterix](https://gym.openai.com/evaluations/eval_mees2c58QfKm5GspCjRfCA)(407432) | [Asteroids](https://gym.openai.com/evaluations/eval_8eHKsRL4RzuZEq9AOLZA)(1965) | [Atlantis](https://gym.openai.com/evaluations/eval_Z1B3d7A1QCaQk1HpO1Rg)(217186) | [BankHeist](https://gym.openai.com/evaluations/eval_hifoaxFTIuLlPd38BjnOw)(1274) |
| [BattleZone](https://gym.openai.com/evaluations/eval_SoLit2bR1qmFoC0AsJF6Q)(29210) | [BeamRider](https://gym.openai.com/evaluations/eval_KuOYumrjQjixwL0spG0iCA)(5972) | [Berzerk](https://gym.openai.com/evaluations/eval_Yri0XQbwRy62NzWILdn5IA)(2289) | [Breakout](https://gym.openai.com/evaluations/eval_NiKaIN4NSUeEIvWqIgVDrA) (667) |
| [Carnival](https://gym.openai.com/evaluations/eval_xJSOlo2lSWaH1wHEOX5vw)(5211) | [Centipede](https://gym.openai.com/evaluations/eval_mc1Kp5e6R42rFdjeMLzkIg)(2909) | [ChopperCommand](https://gym.openai.com/evaluations/eval_tYVKyh7wQieRIKgEvVaCuw)(6031) | [CrazyClimber](https://gym.openai.com/evaluations/eval_bKeBg0QwSgOm6A0I0wDhSw)(105297) |
| [DemonAttack](https://gym.openai.com/evaluations/eval_tt21vVaRCKYzWFcg1Kw)(33992) | [DoubleDunk](https://gym.openai.com/evaluations/eval_FI1GpF4TlCuf29KccTpQ)(23) | [ElevatorAction](https://gym.openai.com/evaluations/eval_SqeAouMvR0icRivx2xprZg)(11377) | [FishingDerby](https://gym.openai.com/evaluations/eval_pPLCnFXsTVaayrIboDOs0g)(34) |
| [Frostbite](https://gym.openai.com/evaluations/eval_qtC3taKFSgWwkO9q9IM4hA)(6824) | [Gopher](https://gym.openai.com/evaluations/eval_KVcpR1YgQkEzrL2VIcAQ)(22595) | [Gravitar](https://gym.openai.com/evaluations/eval_QudrLdVmTpK9HF5juaZr0w)(2144) | [IceHockey](https://gym.openai.com/evaluations/eval_8oWCTwwGS7OUTTGRwBPQkQ)(19) |
| [Jamesbond](https://gym.openai.com/evaluations/eval_mLF7XPi8Tw66pnjP73JsmA)(640) | [JourneyEscape](https://gym.openai.com/evaluations/eval_S9nQuXLRSu7S5x21Ay6AA)(-407) | [Kangaroo](https://gym.openai.com/evaluations/eval_TNJiLB8fTqOPfvINnPXoQ)(6540) | [Krull](https://gym.openai.com/evaluations/eval_dfOS2WzhTh6sn1FuPS9HA)(6100) |
| [KungFuMaster](https://gym.openai.com/evaluations/eval_vNWDShYTRC0MhfIybeUYg)(34767) | [MsPacman](https://gym.openai.com/evaluations/eval_kpL9bSsS4GXsYb9HuEfew)(5738) | [NameThisGame](https://gym.openai.com/evaluations/eval_LZqfv706SdOMtR4ZZIwIsg)(15321) | [Phoenix](https://gym.openai.com/evaluations/eval_uzUruiB3RRKUMvJIxvEzYA)(75312) |
| [Pong](https://gym.openai.com/evaluations/eval_8L7SV59nSW6GGbbP3N4G6w)(21) | [Pooyan](https://gym.openai.com/evaluations/eval_UXFVI34MSAuNTtjZcK8N0A)(5607) | [Qbert](https://gym.openai.com/evaluations/eval_S8XdrbByQ1eWLUD5jtQYIQ)(20182) | [Riverraid](https://gym.openai.com/evaluations/eval_OU4x3DkTfm4uaXy6CIaXg)(14185) |
| [RoadRunner](https://gym.openai.com/evaluations/eval_wINKQTwxT9ipydHOXBhg)(60615) | [Robotank](https://gym.openai.com/evaluations/eval_Gr5c0ld3QACLDPQrGdzbiw)(60) | [Seaquest](https://gym.openai.com/evaluations/eval_pjjgc9POQJK4IuVw8nXlBw)(46890) | SpaceInvaders(3454) |
| [StarGunner](https://gym.openai.com/evaluations/eval_JB5cOJXFSS2cTQ7dXK8Iag)(93480) | [Tennis](https://gym.openai.com/evaluations/eval_gDjJD0MMS1yLm1T0hdqI4g)(23) | Tutankham(275) | [UpNDown](https://gym.openai.com/evaluations/eval_KmkvMJkxQFSED20wFUMdIA)(92163) |
| [VideoPinball](https://gym.openai.com/evaluations/eval_PWwzNhVFR2CxjYvEsPfT1g)(140156) | [WizardOfWor](https://gym.openai.com/evaluations/eval_1oGQhphpQhmzEMIYRrrp0A)(3824) | [Zaxxon](https://gym.openai.com/evaluations/eval_TIQ102EwTrHrOyve2RGfg)(32894) | |
All models above are trained with the `-v0` variant of Atari games.
Note that this variant is quite different from the settings in DeepMind's papers, so the scores are not directly comparable (see the sketch after this list for how the variant is constructed).
The most notable differences are:
+ Each action is randomly repeated 2~4 times.
+ Inputs are RGB instead of greyscale.
+ An episode is limited to 10000 steps.
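As a reference for the list above, here is a minimal sketch (assuming `gym` with the Atari environments installed; `Breakout-v0` is just an example) of creating the same `-v0` variant these models were trained on:

```python
import gym

# Breakout-v0 already applies the settings listed above: random action repeat
# of 2~4 frames, RGB observations, and the 10000-step episode limit.
env = gym.make('Breakout-v0')
obs = env.reset()
print(obs.shape)   # (210, 160, 3): RGB screen, not greyscale
obs, reward, done, info = env.step(env.action_space.sample())
env.close()
```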
......
......@@ -54,7 +54,7 @@ ENV_NAME = None
def get_player(train=False, dumpdir=None):
env = gym.make(ENV_NAME)
if dumpdir:
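# Record a video for every episode: Monitor's default video_callable only records a sparse, scheduled subset of episodes.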
env = gym.wrappers.Monitor(env, dumpdir, video_callable=lambda _: True)
env = FireResetEnv(env)
env = MapState(env, lambda im: cv2.resize(im, IMAGE_SIZE))
env = FrameStack(env, 4)
......@@ -272,7 +272,7 @@ if __name__ == '__main__':
parser.add_argument('--load', help='load model')
parser.add_argument('--env', help='env', required=True)
parser.add_argument('--task', help='task to perform',
choices=['play', 'eval', 'train', 'dump_video'], default='train')
parser.add_argument('--output', help='output directory for videos', default='output_dir')
parser.add_argument('--episode', help='number of episode to eval', default=100, type=int)
args = parser.parse_args()
......@@ -297,10 +297,9 @@ if __name__ == '__main__':
args.episode, render=True)
elif args.task == 'eval':
eval_model_multithread(pred, args.episode, get_player)
elif args.task == 'dump_video':
play_n_episodes(
get_player(train=False, dumpdir=args.output),
pred, args.episode)
else:
train()
......@@ -9,7 +9,7 @@ import tensorpack
from tensorpack import ModelDesc, InputDesc
from tensorpack.utils import logger
from tensorpack.tfutils import (
varreplace, summary, get_current_tower_context, optimizer, gradproc)
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
assert tensorpack.tfutils.common.get_tf_version_number() >= 1.2
......@@ -60,7 +60,7 @@ class Model(ModelDesc):
self.predict_value, 1), name='predict_reward')
summary.add_moving_summary(max_pred_reward)
with tf.variable_scope('target'), varreplace.freeze_variables(skip_collection=True):
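# Variables created in this scope receive no gradients and, with skip_collection=True,
# stay out of the trainable collection; they are only refreshed by the
# update_target_network op defined further below.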
targetQ_predict_value = self.get_DQN_prediction(next_state) # NxA
if self.method != 'Double':
......@@ -96,6 +96,6 @@ class Model(ModelDesc):
target_name = v.op.name
if target_name.startswith('target'):
new_name = target_name.replace('target/', '')
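# copy each online-network variable into its 'target/'-scoped counterpart, e.g. target/fc0/W <- fc0/W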
logger.info("{} <- {}".format(target_name, new_name))
logger.info("Target Network Update: {} <- {}".format(target_name, new_name))
ops.append(v.assign(G.get_tensor_by_name(new_name + ':0')))
return tf.group(*ops, name='update_target_network')
......@@ -138,12 +138,12 @@ class AtariPlayer(gym.Env):
self.last_raw_screen = self._grab_raw_image()
self.ale.act(0)
def reset(self):
if self.ale.game_over():
self._restart_episode()
return self._current_state()
def step(self, act):
oldlives = self.ale.lives()
r = 0
for k in range(self.frame_skip):
......
......@@ -8,6 +8,9 @@ from collections import deque
import gym
from gym import spaces
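# The wrapper classes below use the gym>=0.10 API (reset/step/observation
# instead of the old underscore-prefixed methods), hence this version check.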
_v0, _v1 = gym.__version__.split('.')[:2]
assert int(_v0) > 0 or int(_v1) >= 10, gym.__version__
"""
The following wrappers are copied or modified from openai/baselines:
......@@ -20,7 +23,7 @@ class MapState(gym.ObservationWrapper):
gym.ObservationWrapper.__init__(self, env)
self._func = map_func
def observation(self, obs):
return self._func(obs)
......@@ -32,22 +35,23 @@ class FrameStack(gym.Wrapper):
self.frames = deque([], maxlen=k)
shp = env.observation_space.shape
chan = 1 if len(shp) == 2 else shp[2]
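# An explicit dtype is expected by newer gym; without it Box guesses one and logs a warning.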
self.observation_space = spaces.Box(
low=0, high=255, shape=(shp[0], shp[1], chan * k), dtype=np.uint8)
def reset(self):
"""Clear buffer and re-fill by duplicating the first observation."""
ob = self.env.reset()
for _ in range(self.k - 1):
self.frames.append(np.zeros_like(ob))
self.frames.append(ob)
return self.observation()
def step(self, action):
ob, reward, done, info = self.env.step(action)
self.frames.append(ob)
return self.observation(), reward, done, info
def observation(self):
assert len(self.frames) == self.k
if self.frames[-1].ndim == 2:
return np.stack(self.frames, axis=-1)
......@@ -62,7 +66,7 @@ class _FireResetEnv(gym.Wrapper):
assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
assert len(env.unwrapped.get_action_meanings()) >= 3
def reset(self):
self.env.reset()
obs, _, done, _ = self.env.step(1)
if done:
......@@ -72,6 +76,9 @@ class _FireResetEnv(gym.Wrapper):
self.env.reset()
return obs
def step(self, action):
return self.env.step(action)
def FireResetEnv(env):
if isinstance(env, gym.Wrapper):
......@@ -88,7 +95,7 @@ class LimitLength(gym.Wrapper):
gym.Wrapper.__init__(self, env)
self.k = k
def reset(self):
# This assumes that reset() will really reset the env.
# If the underlying env tries to be smart about reset
# (e.g. end-of-life), the assumption doesn't hold.
......@@ -96,7 +103,7 @@ class LimitLength(gym.Wrapper):
self.cnt = 0
return ob
def step(self, action):
ob, r, done, info = self.env.step(action)
self.cnt += 1
if self.cnt == self.k:
......
......@@ -18,7 +18,7 @@ from tensorpack.utils.utils import get_tqdm_kwargs
def play_one_episode(env, func, render=False):
def predict(s):
"""
Map from observation to action, with 0.01-greedy exploration.
"""
act = func(s[None, :, :, :])[0][0].argmax()
if random.random() < 0.01:
......@@ -45,7 +45,7 @@ def play_n_episodes(player, predfunc, nr, render=False):
print("{}/{}, score={}".format(k, nr, score))
def eval_with_funcs(predictors, nr_eval, get_player_fn, verbose=False):
"""
Args:
predictors ([PredictorBase])
......@@ -67,7 +67,6 @@ def eval_with_funcs(predictors, nr_eval, get_player_fn):
while not self.stopped():
try:
score = play_one_episode(player, self.func)
# print("Score, ", score)
except RuntimeError:
return
self.queue_put_stoppable(self.q, score)
......@@ -80,17 +79,21 @@ def eval_with_funcs(predictors, nr_eval, get_player_fn):
time.sleep(0.1) # avoid simulator bugs
stat = StatCounter()
def fetch():
r = q.get()
stat.feed(r)
if verbose:
logger.info("Score: {}".format(r))
for _ in tqdm(range(nr_eval), **get_tqdm_kwargs()):
fetch()
logger.info("Waiting for all the workers to finish the last run...")
for k in threads:
k.stop()
for k in threads:
k.join()
while q.qsize():
fetch()
if stat.count > 0:
return (stat.average, stat.max)
......@@ -100,11 +103,13 @@ def eval_with_funcs(predictors, nr_eval, get_player_fn):
def eval_model_multithread(pred, nr_eval, get_player_fn):
"""
Args:
pred (OfflinePredictor): state -> [#action]
"""
NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
with pred.sess.as_default():
mean, max = eval_with_funcs(
[pred] * NR_PROC, nr_eval,
get_player_fn, verbose=True)
logger.info("Average Score: {}; Max Score: {}".format(mean, max))
......
......@@ -95,6 +95,11 @@ class ModelDescBase(object):
Args:
args ([tf.Tensor]): tensors that match the list of
:class:`InputDesc` defined by ``_get_inputs``.
Returns:
In general it returns nothing, but a subclass (e.g.
:class:`ModelDesc`) may require it to return necessary information
to build the trainer.
"""
if len(args) == 1:
arg = args[0]
......@@ -124,18 +129,16 @@ class ModelDescBase(object):
class ModelDesc(ModelDescBase):
"""
A ModelDesc with **single cost** and **single optimizer**.
It contains information about InputDesc, how to get cost, and how to get optimizer.
"""
def get_cost(self):
"""
Return the cost tensor to optimize on.
This function takes the cost tensor defined by :meth:`build_graph`,
and applies the collection
``tf.GraphKeys.REGULARIZATION_LOSSES`` to the cost automatically.
"""
cost = self._get_cost()
......@@ -165,6 +168,9 @@ class ModelDesc(ModelDescBase):
raise NotImplementedError()
def _build_graph_get_cost(self, *inputs):
"""
Used by trainers to get the final cost for optimization.
"""
self.build_graph(*inputs)
return self.get_cost()
......