Commit 634369d8 authored by Yuxin Wu

split atari_wrapper from common

parent 7e963996
...@@ -8,6 +8,8 @@ so you won't need to look at here very often.
Here are a list of things that were changed, starting from an early version.
TensorFlow itself also changed APIs before 1.0 and those are not listed here.
+ [2017/10/12](https://github.com/ppwwyyxx/tensorpack/commit/7e963996f615b85f7459455596b4ee9bbd0bce8e).
`tensorpack.RL` was deprecated. The RL examples are written with OpenAI gym interface instead.
+ [2017/10/10](https://github.com/ppwwyyxx/tensorpack/commit/7d40e049691d92018f50dc7d45bba5e8b140becc).
  `tfutils.distributions` was deprecated in favor of `tf.distributions` introduced in TF 1.3.
+ [2017/08/02](https://github.com/ppwwyyxx/tensorpack/commit/875f4d7dbb5675f54eae5675fa3a0948309a8465).
......
...@@ -57,7 +57,6 @@ Models are available for the following atari environments (click to watch videos
| [VideoPinball](https://gym.openai.com/evaluations/eval_PWwzNhVFR2CxjYvEsPfT1g) | [WizardOfWor](https://gym.openai.com/evaluations/eval_1oGQhphpQhmzEMIYRrrp0A) | [Zaxxon](https://gym.openai.com/evaluations/eval_TIQ102EwTrHrOyve2RGfg) | |
Note that atari game settings in gym (AtariGames-v0) are quite different from DeepMind papers, so the scores are not comparable. The most notable differences are:
+ Each action is randomly repeated 2~4 times.
+ Inputs are RGB instead of greyscale.
......
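The defaults mentioned in the note above (random action repeat and RGB observations) can be checked directly from gym. A minimal sketch, assuming a working gym Atari installation; `Breakout-v0` is used purely as an example id:

```
import gym

env = gym.make('Breakout-v0')
# Frameskip is sampled uniformly from [2, 5), i.e. each action is repeated 2~4 times.
print(env.unwrapped.frameskip)        # (2, 5)
# Observations are full RGB frames, not the 84x84 greyscale used in DeepMind papers.
print(env.observation_space.shape)    # (210, 160, 3)
```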
../DeepQNetwork/atari_wrapper.py
\ No newline at end of file
...@@ -29,10 +29,9 @@ from tensorpack.utils.gpu import get_nr_gpu
import gym
from simulator import *
import common
from common import (Evaluator, eval_model_multithread,
                    play_one_episode, play_n_episodes,
                    WarpFrame, FrameStack, FireResetEnv, LimitLength)
from common import (Evaluator, eval_model_multithread,
                    play_one_episode, play_n_episodes)
from atari_wrapper import WarpFrame, FrameStack, FireResetEnv, LimitLength
if six.PY3:
    from concurrent import futures
......
...@@ -21,9 +21,8 @@ from tensorpack.utils.concurrency import *
import tensorflow as tf
from DQNModel import Model as DQNModel
import common
from common import Evaluator, eval_model_multithread, play_n_episodes
from common import FrameStack, WarpFrame, FireResetEnv
from atari_wrapper import FrameStack, WarpFrame, FireResetEnv
from expreplay import ExpReplay
from atari import AtariPlayer
......
...@@ -27,6 +27,8 @@ Double-DQN runs at 60 batches (3840 trained frames, 240 seen frames, 960 game fr
## How to use
Install [ALE](https://github.com/mgbellemare/Arcade-Learning-Environment) and gym.
Download an [atari rom](https://github.com/openai/atari-py/tree/master/atari_py/atari_roms) to
`$TENSORPACK_DATASET/atari_rom/` (defaults to ~/tensorpack_data/atari_rom/), e.g.:
```
...@@ -42,7 +44,7 @@ Start Training:
Watch the agent play:
```
./DQN.py --rom breakout.bin --task play --load trained.model
./DQN.py --rom breakout.bin --task play --load path/to/model
```
A pretrained model on breakout can be downloaded [here](https://drive.google.com/open?id=0B9IPQTvr2BBkN1Jrei1xWW0yR28).
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: atari_wrapper.py

import numpy as np
import cv2
from collections import deque
import gym
from gym import spaces

"""
The following wrappers are copied or modified from openai/baselines:
https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
"""


class WarpFrame(gym.ObservationWrapper):
    def __init__(self, env, shape):
        gym.ObservationWrapper.__init__(self, env)
        self.shape = shape
        obs = env.observation_space
        assert isinstance(obs, spaces.Box)
        chan = 1 if len(obs.shape) == 2 else obs.shape[2]
        shape3d = shape if chan == 1 else shape + (chan,)
        self.observation_space = spaces.Box(low=0, high=255, shape=shape3d)

    def _observation(self, obs):
        return cv2.resize(obs, self.shape)


class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        """Buffer observations and stack across channels (last axis)."""
        gym.Wrapper.__init__(self, env)
        self.k = k
        self.frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        chan = 1 if len(shp) == 2 else shp[2]
        self._base_dim = len(shp)
        self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], chan * k))

    def _reset(self):
        """Clear buffer and re-fill by duplicating the first observation."""
        ob = self.env.reset()
        for _ in range(self.k - 1):
            self.frames.append(np.zeros_like(ob))
        self.frames.append(ob)
        return self._observation()

    def _step(self, action):
        ob, reward, done, info = self.env.step(action)
        self.frames.append(ob)
        return self._observation(), reward, done, info

    def _observation(self):
        assert len(self.frames) == self.k
        if self._base_dim == 2:
            return np.stack(self.frames, axis=-1)
        else:
            return np.concatenate(self.frames, axis=2)


class _FireResetEnv(gym.Wrapper):
    def __init__(self, env):
        """Take action on reset for environments that are fixed until firing."""
        gym.Wrapper.__init__(self, env)
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def _reset(self):
        self.env.reset()
        obs, _, done, _ = self.env.step(1)
        if done:
            self.env.reset()
        obs, _, done, _ = self.env.step(2)
        if done:
            self.env.reset()
        return obs


def FireResetEnv(env):
    if isinstance(env, gym.Wrapper):
        baseenv = env.unwrapped
    else:
        baseenv = env
    if 'FIRE' in baseenv.get_action_meanings():
        return _FireResetEnv(env)
    return env


class LimitLength(gym.Wrapper):
    def __init__(self, env, k):
        gym.Wrapper.__init__(self, env)
        self.k = k

    def _reset(self):
        # This assumes that reset() will really reset the env.
        # If the underlying env tries to be smart about reset
        # (e.g. end-of-life), the assumption doesn't hold.
        ob = self.env.reset()
        self.cnt = 0
        return ob

    def _step(self, action):
        ob, r, done, info = self.env.step(action)
        self.cnt += 1
        if self.cnt == self.k:
            done = True
        return ob, r, done, info
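A minimal sketch of how these wrappers are meant to be composed, mirroring the imports used by the DQN and A3C examples above; the env id, image size, stack depth, and step limit below are illustrative assumptions rather than values fixed by this commit:

```
import gym
from atari_wrapper import FireResetEnv, FrameStack, LimitLength, WarpFrame

env = gym.make('Breakout-v0')    # assumed example; any Atari env works
env = FireResetEnv(env)          # take the FIRE action on reset, for games that need it to start
env = WarpFrame(env, (84, 84))   # resize every observation to 84x84
env = FrameStack(env, 4)         # stack the last 4 frames along the channel axis
env = LimitLength(env, 60000)    # force `done` after 60000 steps

ob = env.reset()
print(ob.shape)                  # (84, 84, 12): 4 stacked RGB frames
```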
...@@ -4,17 +4,10 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import random
import time
import threading
import multiprocessing
import numpy as np
import cv2
from collections import deque
from tqdm import tqdm
from six.moves import queue
import gym
from gym import spaces
from tensorpack.utils.concurrency import StoppableThread, ShareSessionThread
from tensorpack.callbacks import Triggerable
from tensorpack.utils import logger
...@@ -138,105 +131,3 @@ class Evaluator(Triggerable):
        self.eval_episode = int(self.eval_episode * 0.94)
        self.trainer.monitors.put_scalar('mean_score', mean)
        self.trainer.monitors.put_scalar('max_score', max)
"""
------------------------------------------------------------------------------
The following wrappers are copied or modified from openai/baselines:
https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
"""
class WarpFrame(gym.ObservationWrapper):
def __init__(self, env, shape):
gym.ObservationWrapper.__init__(self, env)
self.shape = shape
obs = env.observation_space
assert isinstance(obs, spaces.Box)
chan = 1 if len(obs.shape) == 2 else obs.shape[2]
shape3d = shape if chan == 1 else shape + (chan,)
self.observation_space = spaces.Box(low=0, high=255, shape=shape3d)
def _observation(self, obs):
return cv2.resize(obs, self.shape)
class FrameStack(gym.Wrapper):
def __init__(self, env, k):
"""Buffer observations and stack across channels (last axis)."""
gym.Wrapper.__init__(self, env)
self.k = k
self.frames = deque([], maxlen=k)
shp = env.observation_space.shape
chan = 1 if len(shp) == 2 else shp[2]
self._base_dim = len(shp)
self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], chan * k))
def _reset(self):
"""Clear buffer and re-fill by duplicating the first observation."""
ob = self.env.reset()
for _ in range(self.k - 1):
self.frames.append(np.zeros_like(ob))
self.frames.append(ob)
return self._observation()
def _step(self, action):
ob, reward, done, info = self.env.step(action)
self.frames.append(ob)
return self._observation(), reward, done, info
def _observation(self):
assert len(self.frames) == self.k
if self._base_dim == 2:
return np.stack(self.frames, axis=-1)
else:
return np.concatenate(self.frames, axis=2)
class _FireResetEnv(gym.Wrapper):
def __init__(self, env):
"""Take action on reset for environments that are fixed until firing."""
gym.Wrapper.__init__(self, env)
assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
assert len(env.unwrapped.get_action_meanings()) >= 3
def _reset(self):
self.env.reset()
obs, _, done, _ = self.env.step(1)
if done:
self.env.reset()
obs, _, done, _ = self.env.step(2)
if done:
self.env.reset()
return obs
def FireResetEnv(env):
if isinstance(env, gym.Wrapper):
baseenv = env.unwrapped
else:
baseenv = env
if 'FIRE' in baseenv.get_action_meanings():
return _FireResetEnv(env)
return env
class LimitLength(gym.Wrapper):
def __init__(self, env, k):
gym.Wrapper.__init__(self, env)
self.k = k
def _reset(self):
# This assumes that reset() will really reset the env.
# If the underlying env tries to be smart about reset
# (e.g. end-of-life), the assumption doesn't hold.
ob = self.env.reset()
self.cnt = 0
return ob
def _step(self, action):
ob, r, done, info = self.env.step(action)
self.cnt += 1
if self.cnt == self.k:
done = True
return ob, r, done, info