Commit 634369d8 authored by Yuxin Wu

split atari_wrapper from common

parent 7e963996
......@@ -8,6 +8,8 @@ so you won't need to look here very often.
Here is a list of things that were changed, starting from an early version.
TensorFlow itself also changed APIs before 1.0 and those are not listed here.
+ [2017/10/12](https://github.com/ppwwyyxx/tensorpack/commit/7e963996f615b85f7459455596b4ee9bbd0bce8e).
`tensorpack.RL` was deprecated. The RL examples are written with the OpenAI gym interface instead.
+ [2017/10/10](https://github.com/ppwwyyxx/tensorpack/commit/7d40e049691d92018f50dc7d45bba5e8b140becc).
`tfutils.distributions` was deprecated in favor of `tf.distributions` introduced in TF 1.3.
+ [2017/08/02](https://github.com/ppwwyyxx/tensorpack/commit/875f4d7dbb5675f54eae5675fa3a0948309a8465).
......
......@@ -57,7 +57,6 @@ Models are available for the following atari environments (click to watch videos
| [VideoPinball](https://gym.openai.com/evaluations/eval_PWwzNhVFR2CxjYvEsPfT1g) | [WizardOfWor](https://gym.openai.com/evaluations/eval_1oGQhphpQhmzEMIYRrrp0A) | [Zaxxon](https://gym.openai.com/evaluations/eval_TIQ102EwTrHrOyve2RGfg) | |
Note that the Atari game settings in gym (AtariGames-v0) are quite different from those in the DeepMind papers, so the scores are not directly comparable. The most notable differences are:
+ Each action is randomly repeated 2~4 times.
+ Inputs are RGB instead of greyscale.
......
../DeepQNetwork/atari_wrapper.py
\ No newline at end of file
......@@ -29,10 +29,9 @@ from tensorpack.utils.gpu import get_nr_gpu
import gym
from simulator import *
import common
from common import (Evaluator, eval_model_multithread,
play_one_episode, play_n_episodes,
WarpFrame, FrameStack, FireResetEnv, LimitLength)
play_one_episode, play_n_episodes)
from atari_wrapper import WarpFrame, FrameStack, FireResetEnv, LimitLength
if six.PY3:
from concurrent import futures
......
......@@ -21,9 +21,8 @@ from tensorpack.utils.concurrency import *
import tensorflow as tf
from DQNModel import Model as DQNModel
import common
from common import Evaluator, eval_model_multithread, play_n_episodes
from common import FrameStack, WarpFrame, FireResetEnv
from atari_wrapper import FrameStack, WarpFrame, FireResetEnv
from expreplay import ExpReplay
from atari import AtariPlayer
......
......@@ -27,6 +27,8 @@ Double-DQN runs at 60 batches (3840 trained frames, 240 seen frames, 960 game fr
## How to use
Install [ALE](https://github.com/mgbellemare/Arcade-Learning-Environment) and gym.
Download an [atari rom](https://github.com/openai/atari-py/tree/master/atari_py/atari_roms) to
`$TENSORPACK_DATASET/atari_rom/` (defaults to ~/tensorpack_data/atari_rom/), e.g.:
```
......@@ -42,7 +44,7 @@ Start Training:
Watch the agent play:
```
./DQN.py --rom breakout.bin --task play --load trained.model
./DQN.py --rom breakout.bin --task play --load path/to/model
```
A pretrained model on breakout can be downloaded [here](https://drive.google.com/open?id=0B9IPQTvr2BBkN1Jrei1xWW0yR28).
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: atari_wrapper.py

import numpy as np
import cv2
from collections import deque
import gym
from gym import spaces

"""
The following wrappers are copied or modified from openai/baselines:
https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
"""


class WarpFrame(gym.ObservationWrapper):
    def __init__(self, env, shape):
        gym.ObservationWrapper.__init__(self, env)
        self.shape = shape
        obs = env.observation_space
        assert isinstance(obs, spaces.Box)
        chan = 1 if len(obs.shape) == 2 else obs.shape[2]
        shape3d = shape if chan == 1 else shape + (chan,)
        self.observation_space = spaces.Box(low=0, high=255, shape=shape3d)

    def _observation(self, obs):
        return cv2.resize(obs, self.shape)


class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        """Buffer observations and stack across channels (last axis)."""
        gym.Wrapper.__init__(self, env)
        self.k = k
        self.frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        chan = 1 if len(shp) == 2 else shp[2]
        self._base_dim = len(shp)
        self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], chan * k))

    def _reset(self):
"""Clear buffer and re-fill by duplicating the first observation."""
        ob = self.env.reset()
        for _ in range(self.k - 1):
            self.frames.append(np.zeros_like(ob))
        self.frames.append(ob)
        return self._observation()

    def _step(self, action):
        ob, reward, done, info = self.env.step(action)
        self.frames.append(ob)
        return self._observation(), reward, done, info

    def _observation(self):
        assert len(self.frames) == self.k
        if self._base_dim == 2:
            return np.stack(self.frames, axis=-1)
        else:
            return np.concatenate(self.frames, axis=2)


class _FireResetEnv(gym.Wrapper):
    def __init__(self, env):
        """Take action on reset for environments that are fixed until firing."""
        gym.Wrapper.__init__(self, env)
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def _reset(self):
        self.env.reset()
        obs, _, done, _ = self.env.step(1)
        if done:
            self.env.reset()
        obs, _, done, _ = self.env.step(2)
        if done:
            self.env.reset()
        return obs


def FireResetEnv(env):
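    # Wrap with _FireResetEnv only when the game's action set actually contains
    # FIRE; otherwise return the env unchanged.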
    if isinstance(env, gym.Wrapper):
        baseenv = env.unwrapped
    else:
        baseenv = env
    if 'FIRE' in baseenv.get_action_meanings():
        return _FireResetEnv(env)
    return env


class LimitLength(gym.Wrapper):
    def __init__(self, env, k):
        gym.Wrapper.__init__(self, env)
        self.k = k

    def _reset(self):
        # This assumes that reset() will really reset the env.
        # If the underlying env tries to be smart about reset
        # (e.g. end-of-life), the assumption doesn't hold.
        ob = self.env.reset()
        self.cnt = 0
        return ob

    def _step(self, action):
        ob, r, done, info = self.env.step(action)
        self.cnt += 1
        if self.cnt == self.k:
            done = True
        return ob, r, done, info
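
For reference, a minimal sketch of how these wrappers might be composed on a gym Atari environment. The game name, the 84x84 frame size, the stack length of 4 and the step cap below are illustrative values, not settings prescribed by this commit, and the old-style `_reset`/`_step` dispatch assumes a gym version contemporary with it:
```
import gym
from atari_wrapper import FireResetEnv, FrameStack, LimitLength, WarpFrame

# Illustrative composition; the game, frame size, stack length and step cap
# are example values, not taken from this repository.
env = gym.make('Breakout-v0')
env = FireResetEnv(env)          # press FIRE on reset when the game requires it
env = WarpFrame(env, (84, 84))   # resize each RGB observation to 84x84
env = FrameStack(env, 4)         # stack the last 4 frames along the channel axis
env = LimitLength(env, 60000)    # end an episode after at most 60000 steps

ob = env.reset()                 # shape (84, 84, 12) for an RGB base observation
done = False
while not done:
    ob, reward, done, info = env.step(env.action_space.sample())
```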
......@@ -4,17 +4,10 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import random
import time
import threading
import multiprocessing
import numpy as np
import cv2
from collections import deque
from tqdm import tqdm
from six.moves import queue
import gym
from gym import spaces
from tensorpack.utils.concurrency import StoppableThread, ShareSessionThread
from tensorpack.callbacks import Triggerable
from tensorpack.utils import logger
......@@ -138,105 +131,3 @@ class Evaluator(Triggerable):
self.eval_episode = int(self.eval_episode * 0.94)
self.trainer.monitors.put_scalar('mean_score', mean)
self.trainer.monitors.put_scalar('max_score', max)
"""
------------------------------------------------------------------------------
The following wrappers are copied or modified from openai/baselines:
https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
"""
class WarpFrame(gym.ObservationWrapper):
    def __init__(self, env, shape):
        gym.ObservationWrapper.__init__(self, env)
        self.shape = shape
        obs = env.observation_space
        assert isinstance(obs, spaces.Box)
        chan = 1 if len(obs.shape) == 2 else obs.shape[2]
        shape3d = shape if chan == 1 else shape + (chan,)
        self.observation_space = spaces.Box(low=0, high=255, shape=shape3d)

    def _observation(self, obs):
        return cv2.resize(obs, self.shape)


class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        """Buffer observations and stack across channels (last axis)."""
        gym.Wrapper.__init__(self, env)
        self.k = k
        self.frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        chan = 1 if len(shp) == 2 else shp[2]
        self._base_dim = len(shp)
        self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], chan * k))

    def _reset(self):
        """Clear buffer and re-fill by duplicating the first observation."""
        ob = self.env.reset()
        for _ in range(self.k - 1):
            self.frames.append(np.zeros_like(ob))
        self.frames.append(ob)
        return self._observation()

    def _step(self, action):
        ob, reward, done, info = self.env.step(action)
        self.frames.append(ob)
        return self._observation(), reward, done, info

    def _observation(self):
        assert len(self.frames) == self.k
        if self._base_dim == 2:
            return np.stack(self.frames, axis=-1)
        else:
            return np.concatenate(self.frames, axis=2)


class _FireResetEnv(gym.Wrapper):
    def __init__(self, env):
        """Take action on reset for environments that are fixed until firing."""
        gym.Wrapper.__init__(self, env)
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def _reset(self):
        self.env.reset()
        obs, _, done, _ = self.env.step(1)
        if done:
            self.env.reset()
        obs, _, done, _ = self.env.step(2)
        if done:
            self.env.reset()
        return obs


def FireResetEnv(env):
    if isinstance(env, gym.Wrapper):
        baseenv = env.unwrapped
    else:
        baseenv = env
    if 'FIRE' in baseenv.get_action_meanings():
        return _FireResetEnv(env)
    return env


class LimitLength(gym.Wrapper):
    def __init__(self, env, k):
        gym.Wrapper.__init__(self, env)
        self.k = k

    def _reset(self):
        # This assumes that reset() will really reset the env.
        # If the underlying env tries to be smart about reset
        # (e.g. end-of-life), the assumption doesn't hold.
        ob = self.env.reset()
        self.cnt = 0
        return ob

    def _step(self, action):
        ob, r, done, info = self.env.step(action)
        self.cnt += 1
        if self.cnt == self.k:
            done = True
        return ob, r, done, info