Commit 55098813 authored by Yuxin Wu

Upgrade gym

parent 4bc0c748
@@ -362,7 +362,7 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
# include_init_with_doc doesn't work well for decorated init
# https://github.com/sphinx-doc/sphinx/issues/4258
return False
-# hide deprecated stuff
+# Hide some names that are deprecated or not intended to be used
if name in [
# deprecated stuff:
'GaussianDeform',
...
### A3C code and models for Atari games in gym
Multi-GPU version of the A3C algorithm in
-[Asynchronous Methods for Deep Reinforcement Learning](http://arxiv.org/abs/1602.01783),
-with <500 lines of code.
-Results of the same code trained on 47 different Atari games were uploaded on OpenAI Gym.
-You can see them in [my gym page](https://gym.openai.com/users/ppwwyyxx).
-Most of them are the best reproducible results on gym.
+[Asynchronous Methods for Deep Reinforcement Learning](http://arxiv.org/abs/1602.01783).
+Results of the same code trained on 47 different Atari games were uploaded to OpenAI Gym.
+Most of them were the best reproducible results on gym.
+However, OpenAI later removed the leaderboard from their site entirely.
### To train on an Atari game:
@@ -17,10 +16,10 @@ The speed is about 6~10 iterations/s on 1 GPU plus 12+ CPU cores.
With 2 TitanX + 20+ CPU cores, by setting `SIMULATOR_PROC=240, PREDICT_BATCH_SIZE=30, PREDICTOR_THREAD_PER_GPU=6`, it can improve to 16 it/s (2K images/s).
Note that the network architecture is larger than what's used in the original paper.
-The uploaded models are all trained with 4 GPUs for about 2 days.
+The pretrained models are all trained with 4 GPUs for about 2 days.
But on simple games like Breakout, you can get good performance within several hours.
Also note that multi-GPU doesn't give you obvious speedup here,
-because the bottleneck in this implementation is not computation but data.
+because the bottleneck in this implementation is not computation but simulation.
Some practical notes:
@@ -36,28 +35,32 @@ Download models from [model zoo](http://models.tensorpack.com/OpenAIGym/).
Watch the agent play:
`./train-atari.py --task play --env Breakout-v0 --load Breakout-v0.npz`
-Generate gym submissions:
+Dump some videos:
-`./train-atari.py --task gen_submit --load Breakout-v0.npz --env Breakout-v0 --output output_dir`
+`./train-atari.py --task dump_video --load Breakout-v0.npz --env Breakout-v0 --output output_dir --episode 3`
-Models are available for the following atari environments (click to watch videos of my agent):
+This table lists available pretrained models and scores (average over 100 episodes),
+with their submission links.
+The site is not maintained anymore, so the links might become invalid at any time.
| | | | |
| - | - | - | - |
| [AirRaid](https://gym.openai.com/evaluations/eval_zIeNk5MxSGOmvGEUxrZDUw) (2727) | [Alien](https://gym.openai.com/evaluations/eval_8NR1IvjTQkSIT6En4xSMA) (2611) | [Amidar](https://gym.openai.com/evaluations/eval_HwEazbHtTYGpCialv9uPhA) (1376) | [Assault](https://gym.openai.com/evaluations/eval_tCiHwy5QrSdFVucSbBV6Q) (3397) |
| [Asterix](https://gym.openai.com/evaluations/eval_mees2c58QfKm5GspCjRfCA) (407432) | [Asteroids](https://gym.openai.com/evaluations/eval_8eHKsRL4RzuZEq9AOLZA) (1965) | [Atlantis](https://gym.openai.com/evaluations/eval_Z1B3d7A1QCaQk1HpO1Rg) (217186) | [BankHeist](https://gym.openai.com/evaluations/eval_hifoaxFTIuLlPd38BjnOw) (1274) |
| [BattleZone](https://gym.openai.com/evaluations/eval_SoLit2bR1qmFoC0AsJF6Q) (29210) | [BeamRider](https://gym.openai.com/evaluations/eval_KuOYumrjQjixwL0spG0iCA) (5972) | [Berzerk](https://gym.openai.com/evaluations/eval_Yri0XQbwRy62NzWILdn5IA) (2289) | [Breakout](https://gym.openai.com/evaluations/eval_NiKaIN4NSUeEIvWqIgVDrA) (667) |
| [Carnival](https://gym.openai.com/evaluations/eval_xJSOlo2lSWaH1wHEOX5vw) (5211) | [Centipede](https://gym.openai.com/evaluations/eval_mc1Kp5e6R42rFdjeMLzkIg) (2909) | [ChopperCommand](https://gym.openai.com/evaluations/eval_tYVKyh7wQieRIKgEvVaCuw) (6031) | [CrazyClimber](https://gym.openai.com/evaluations/eval_bKeBg0QwSgOm6A0I0wDhSw) (105297) |
| [DemonAttack](https://gym.openai.com/evaluations/eval_tt21vVaRCKYzWFcg1Kw) (33992) | [DoubleDunk](https://gym.openai.com/evaluations/eval_FI1GpF4TlCuf29KccTpQ) (23) | [ElevatorAction](https://gym.openai.com/evaluations/eval_SqeAouMvR0icRivx2xprZg) (11377) | [FishingDerby](https://gym.openai.com/evaluations/eval_pPLCnFXsTVaayrIboDOs0g) (34) |
| [Frostbite](https://gym.openai.com/evaluations/eval_qtC3taKFSgWwkO9q9IM4hA) (6824) | [Gopher](https://gym.openai.com/evaluations/eval_KVcpR1YgQkEzrL2VIcAQ) (22595) | [Gravitar](https://gym.openai.com/evaluations/eval_QudrLdVmTpK9HF5juaZr0w) (2144) | [IceHockey](https://gym.openai.com/evaluations/eval_8oWCTwwGS7OUTTGRwBPQkQ) (19) |
| [Jamesbond](https://gym.openai.com/evaluations/eval_mLF7XPi8Tw66pnjP73JsmA) (640) | [JourneyEscape](https://gym.openai.com/evaluations/eval_S9nQuXLRSu7S5x21Ay6AA) (-407) | [Kangaroo](https://gym.openai.com/evaluations/eval_TNJiLB8fTqOPfvINnPXoQ) (6540) | [Krull](https://gym.openai.com/evaluations/eval_dfOS2WzhTh6sn1FuPS9HA) (6100) |
| [KungFuMaster](https://gym.openai.com/evaluations/eval_vNWDShYTRC0MhfIybeUYg) (34767) | [MsPacman](https://gym.openai.com/evaluations/eval_kpL9bSsS4GXsYb9HuEfew) (5738) | [NameThisGame](https://gym.openai.com/evaluations/eval_LZqfv706SdOMtR4ZZIwIsg) (15321) | [Phoenix](https://gym.openai.com/evaluations/eval_uzUruiB3RRKUMvJIxvEzYA) (75312) |
| [Pong](https://gym.openai.com/evaluations/eval_8L7SV59nSW6GGbbP3N4G6w) (21) | [Pooyan](https://gym.openai.com/evaluations/eval_UXFVI34MSAuNTtjZcK8N0A) (5607) | [Qbert](https://gym.openai.com/evaluations/eval_S8XdrbByQ1eWLUD5jtQYIQ) (20182) | [Riverraid](https://gym.openai.com/evaluations/eval_OU4x3DkTfm4uaXy6CIaXg) (14185) |
| [RoadRunner](https://gym.openai.com/evaluations/eval_wINKQTwxT9ipydHOXBhg) (60615) | [Robotank](https://gym.openai.com/evaluations/eval_Gr5c0ld3QACLDPQrGdzbiw) (60) | [Seaquest](https://gym.openai.com/evaluations/eval_pjjgc9POQJK4IuVw8nXlBw) (46890) | SpaceInvaders (3454) |
| [StarGunner](https://gym.openai.com/evaluations/eval_JB5cOJXFSS2cTQ7dXK8Iag) (93480) | [Tennis](https://gym.openai.com/evaluations/eval_gDjJD0MMS1yLm1T0hdqI4g) (23) | Tutankham (275) | [UpNDown](https://gym.openai.com/evaluations/eval_KmkvMJkxQFSED20wFUMdIA) (92163) |
| [VideoPinball](https://gym.openai.com/evaluations/eval_PWwzNhVFR2CxjYvEsPfT1g) (140156) | [WizardOfWor](https://gym.openai.com/evaluations/eval_1oGQhphpQhmzEMIYRrrp0A) (3824) | [Zaxxon](https://gym.openai.com/evaluations/eval_TIQ102EwTrHrOyve2RGfg) (32894) | |
-Note that atari game settings in gym (AtariGames-v0) are quite different from DeepMind papers, so the scores are not comparable. The most notable differences are:
+All models above are trained with the `-v0` variant of atari games.
+Note that this variant is quite different from DeepMind papers, so the scores are not directly comparable.
+The most notable differences are:
+ Each action is randomly repeated 2~4 times.
+ Inputs are RGB instead of greyscale.
+ An episode is limited to 10000 steps.
...
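As a rough way to double-check the `-v0` settings listed above against your own installation, the snippet below inspects a freshly made environment. It is only a sketch: attribute names such as `frameskip` and `max_episode_steps` may vary across gym versions, and `Breakout-v0` is just a representative choice.

```python
import gym

env = gym.make('Breakout-v0')
# Random action repeat: for the -v0 Atari envs this is typically a range
# such as (2, 5), i.e. each action is repeated 2~4 times.
print(env.unwrapped.frameskip)
# RGB observations rather than greyscale: shape is (210, 160, 3).
print(env.observation_space.shape)
# Episode length cap registered for the -v0 variant (10000 steps).
print(env.spec.max_episode_steps)
env.close()
```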
@@ -54,7 +54,7 @@ ENV_NAME = None
def get_player(train=False, dumpdir=None):
env = gym.make(ENV_NAME)
if dumpdir:
-env = gym.wrappers.Monitor(env, dumpdir)
+env = gym.wrappers.Monitor(env, dumpdir, video_callable=lambda _: True)
env = FireResetEnv(env)
env = MapState(env, lambda im: cv2.resize(im, IMAGE_SIZE))
env = FrameStack(env, 4)
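For context on the `video_callable` argument added above: gym's `Monitor` wrapper of this era records only a sparse schedule of episodes by default, while a callable that always returns True records every episode. A minimal standalone sketch (the environment name and output directory are placeholders for illustration):

```python
import gym

# Record every episode instead of Monitor's default (sparse) video schedule.
env = gym.make('Breakout-v0')
env = gym.wrappers.Monitor(env, '/tmp/gym-videos',
                           video_callable=lambda episode_id: True)
```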
@@ -272,7 +272,7 @@ if __name__ == '__main__':
parser.add_argument('--load', help='load model')
parser.add_argument('--env', help='env', required=True)
parser.add_argument('--task', help='task to perform',
-choices=['play', 'eval', 'train', 'gen_submit'], default='train')
+choices=['play', 'eval', 'train', 'dump_video'], default='train')
parser.add_argument('--output', help='output directory for submission', default='output_dir')
parser.add_argument('--episode', help='number of episode to eval', default=100, type=int)
args = parser.parse_args()
@@ -297,10 +297,9 @@ if __name__ == '__main__':
args.episode, render=True)
elif args.task == 'eval':
eval_model_multithread(pred, args.episode, get_player)
-elif args.task == 'gen_submit':
+elif args.task == 'dump_video':
play_n_episodes(
get_player(train=False, dumpdir=args.output),
pred, args.episode)
-# gym.upload(args.output, api_key='xxx')
else:
train()
@@ -9,7 +9,7 @@ import tensorpack
from tensorpack import ModelDesc, InputDesc
from tensorpack.utils import logger
from tensorpack.tfutils import (
-summary, get_current_tower_context, optimizer, gradproc)
+varreplace, summary, get_current_tower_context, optimizer, gradproc)
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
assert tensorpack.tfutils.common.get_tf_version_number() >= 1.2
@@ -60,7 +60,7 @@ class Model(ModelDesc):
self.predict_value, 1), name='predict_reward')
summary.add_moving_summary(max_pred_reward)
-with tf.variable_scope('target'):
+with tf.variable_scope('target'), varreplace.freeze_variables(skip_collection=True):
targetQ_predict_value = self.get_DQN_prediction(next_state) # NxA
if self.method != 'Double':
@@ -96,6 +96,6 @@ class Model(ModelDesc):
target_name = v.op.name
if target_name.startswith('target'):
new_name = target_name.replace('target/', '')
-logger.info("{} <- {}".format(target_name, new_name))
+logger.info("Target Network Update: {} <- {}".format(target_name, new_name))
ops.append(v.assign(G.get_tensor_by_name(new_name + ':0')))
return tf.group(*ops, name='update_target_network')
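The `freeze_variables(skip_collection=True)` context above keeps the target network's weights out of the trainable-variable collection, so the optimizer never updates them; they only change through the explicit `update_target_network` op. A rough, framework-agnostic illustration of the same idea in plain TensorFlow 1.x (variable names and shapes here are made up for the example):

```python
import tensorflow as tf

# The online Q-network variables are trainable as usual.
with tf.variable_scope('online'):
    w_online = tf.get_variable('w', shape=[84, 84], dtype=tf.float32)

# The target network mirrors the online one, but trainable=False keeps its
# weights out of TRAINABLE_VARIABLES, so gradient updates never touch them;
# roughly what freeze_variables(skip_collection=True) accomplishes above.
with tf.variable_scope('target'):
    w_target = tf.get_variable('w', shape=[84, 84], dtype=tf.float32,
                               trainable=False)

# The only way the target weights change is an explicit copy, analogous
# to the update_target_network op built in the hunk above.
update_target = tf.assign(w_target, w_online)
```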
@@ -138,12 +138,12 @@ class AtariPlayer(gym.Env):
self.last_raw_screen = self._grab_raw_image()
self.ale.act(0)
-def _reset(self):
+def reset(self):
if self.ale.game_over():
self._restart_episode()
return self._current_state()
-def _step(self, act):
+def step(self, act):
oldlives = self.ale.lives()
r = 0
for k in range(self.frame_skip):
...
@@ -8,6 +8,9 @@ from collections import deque
import gym
from gym import spaces
+_v0, _v1 = gym.__version__.split('.')[:2]
+assert int(_v0) > 0 or int(_v1) >= 10, gym.__version__
"""
The following wrappers are copied or modified from openai/baselines:
@@ -20,7 +23,7 @@ class MapState(gym.ObservationWrapper):
gym.ObservationWrapper.__init__(self, env)
self._func = map_func
-def _observation(self, obs):
+def observation(self, obs):
return self._func(obs)
@@ -32,22 +35,23 @@ class FrameStack(gym.Wrapper):
self.frames = deque([], maxlen=k)
shp = env.observation_space.shape
chan = 1 if len(shp) == 2 else shp[2]
-self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], chan * k))
+self.observation_space = spaces.Box(
+low=0, high=255, shape=(shp[0], shp[1], chan * k), dtype=np.uint8)
-def _reset(self):
+def reset(self):
"""Clear buffer and re-fill by duplicating the first observation."""
ob = self.env.reset()
for _ in range(self.k - 1):
self.frames.append(np.zeros_like(ob))
self.frames.append(ob)
-return self._observation()
+return self.observation()
-def _step(self, action):
+def step(self, action):
ob, reward, done, info = self.env.step(action)
self.frames.append(ob)
-return self._observation(), reward, done, info
+return self.observation(), reward, done, info
-def _observation(self):
+def observation(self):
assert len(self.frames) == self.k
if self.frames[-1].ndim == 2:
return np.stack(self.frames, axis=-1)
@@ -62,7 +66,7 @@ class _FireResetEnv(gym.Wrapper):
assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
assert len(env.unwrapped.get_action_meanings()) >= 3
-def _reset(self):
+def reset(self):
self.env.reset()
obs, _, done, _ = self.env.step(1)
if done:
@@ -72,6 +76,9 @@ class _FireResetEnv(gym.Wrapper):
self.env.reset()
return obs
+def step(self, action):
+return self.env.step(action)
def FireResetEnv(env):
if isinstance(env, gym.Wrapper):
@@ -88,7 +95,7 @@ class LimitLength(gym.Wrapper):
gym.Wrapper.__init__(self, env)
self.k = k
-def _reset(self):
+def reset(self):
# This assumes that reset() will really reset the env.
# If the underlying env tries to be smart about reset
# (e.g. end-of-life), the assumption doesn't hold.
@@ -96,7 +103,7 @@ class LimitLength(gym.Wrapper):
self.cnt = 0
return ob
-def _step(self, action):
+def step(self, action):
ob, r, done, info = self.env.step(action)
self.cnt += 1
if self.cnt == self.k:
...
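The renames throughout this file (`_reset`/`_step`/`_observation` to `reset`/`step`/`observation`) follow the wrapper interface that newer gym releases, roughly the 0.10 line the version assert above targets, dispatch to directly. A minimal sketch of a wrapper written against that interface; the class name and the reward clipping are invented purely for illustration:

```python
import gym

class ClipReward(gym.Wrapper):
    """Hypothetical wrapper using the underscore-free method names
    (reset/step) that newer gym versions call directly."""
    def reset(self):
        return self.env.reset()

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        # Clip rewards to [-1, 1], a common Atari preprocessing step.
        return obs, max(-1.0, min(1.0, reward)), done, info

# Usage sketch (assumes an Atari environment such as Breakout-v0 is installed):
# env = ClipReward(gym.make('Breakout-v0'))
```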
@@ -18,7 +18,7 @@ from tensorpack.utils.utils import get_tqdm_kwargs
def play_one_episode(env, func, render=False):
def predict(s):
"""
-Map from observation to action, with 0.001 greedy.
+Map from observation to action, with 0.01 greedy.
"""
act = func(s[None, :, :, :])[0][0].argmax()
if random.random() < 0.01:
@@ -45,7 +45,7 @@ def play_n_episodes(player, predfunc, nr, render=False):
print("{}/{}, score={}".format(k, nr, score))
-def eval_with_funcs(predictors, nr_eval, get_player_fn):
+def eval_with_funcs(predictors, nr_eval, get_player_fn, verbose=False):
"""
Args:
predictors ([PredictorBase])
@@ -67,7 +67,6 @@ def eval_with_funcs(predictors, nr_eval, get_player_fn):
while not self.stopped():
try:
score = play_one_episode(player, self.func)
-# print("Score, ", score)
except RuntimeError:
return
self.queue_put_stoppable(self.q, score)
@@ -80,17 +79,21 @@ def eval_with_funcs(predictors, nr_eval, get_player_fn):
time.sleep(0.1) # avoid simulator bugs
stat = StatCounter()
-for _ in tqdm(range(nr_eval), **get_tqdm_kwargs()):
+def fetch():
r = q.get()
stat.feed(r)
+if verbose:
+logger.info("Score: {}".format(r))
+for _ in tqdm(range(nr_eval), **get_tqdm_kwargs()):
+fetch()
logger.info("Waiting for all the workers to finish the last run...")
for k in threads:
k.stop()
for k in threads:
k.join()
while q.qsize():
-r = q.get()
-stat.feed(r)
+fetch()
if stat.count > 0:
return (stat.average, stat.max)
@@ -100,11 +103,13 @@ def eval_with_funcs(predictors, nr_eval, get_player_fn):
def eval_model_multithread(pred, nr_eval, get_player_fn):
"""
Args:
-pred (OfflinePredictor): state -> Qvalue
+pred (OfflinePredictor): state -> [#action]
"""
NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
with pred.sess.as_default():
-mean, max = eval_with_funcs([pred] * NR_PROC, nr_eval, get_player_fn)
+mean, max = eval_with_funcs(
+[pred] * NR_PROC, nr_eval,
+get_player_fn, verbose=True)
logger.info("Average Score: {}; Max Score: {}".format(mean, max))
...
@@ -95,6 +95,11 @@ class ModelDescBase(object):
Args:
args ([tf.Tensor]): tensors that matches the list of
:class:`InputDesc` defined by ``_get_inputs``.
+Returns:
+In general it returns nothing, but a subclass (e.g.
+:class:`ModelDesc`) may require it to return necessary information
+to build the trainer.
"""
if len(args) == 1:
arg = args[0]
@@ -124,18 +129,16 @@ class ModelDescBase(object):
class ModelDesc(ModelDescBase):
"""
-A ModelDesc with single cost and single optimizer.
+A ModelDesc with **single cost** and **single optimizer**.
It contains information about InputDesc, how to get cost, and how to get optimizer.
"""
def get_cost(self):
"""
-Return the cost tensor in the graph.
-It calls :meth:`ModelDesc._get_cost()` which by default returns
-``self.cost``. You can override :meth:`_get_cost()` if needed.
-This function also applies the collection
+Return the cost tensor to optimize on.
+This function takes the cost tensor defined by :meth:`build_graph`,
+and applies the collection
``tf.GraphKeys.REGULARIZATION_LOSSES`` to the cost automatically.
"""
cost = self._get_cost()
@@ -165,6 +168,9 @@ class ModelDesc(ModelDescBase):
raise NotImplementedError()
def _build_graph_get_cost(self, *inputs):
+"""
+Used by trainers to get the final cost for optimization.
+"""
self.build_graph(*inputs)
return self.get_cost()
...
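To make the single-cost/single-optimizer contract described in these docstrings concrete, here is a minimal, hypothetical `ModelDesc` subclass written against the tensorpack API of this era (`_get_inputs`/`_build_graph`/`_get_optimizer`). The model itself is invented purely for illustration, and method names may differ in other tensorpack versions:

```python
import tensorflow as tf
from tensorpack import ModelDesc, InputDesc

class ToyRegression(ModelDesc):
    """Hypothetical minimal model: one cost tensor, one optimizer."""

    def _get_inputs(self):
        # Declares the tensors that build_graph() will receive, in order.
        return [InputDesc(tf.float32, (None, 10), 'x'),
                InputDesc(tf.float32, (None, 1), 'y')]

    def _build_graph(self, inputs):
        x, y = inputs
        pred = tf.layers.dense(x, 1)
        # get_cost() picks this up by default (via _get_cost) and applies
        # any REGULARIZATION_LOSSES on top of it.
        self.cost = tf.reduce_mean(tf.squared_difference(pred, y), name='cost')

    def _get_optimizer(self):
        return tf.train.AdamOptimizer(1e-3)
```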