Commit 0dbcbac7 authored by Yuxin Wu

DQN supports gym as well.

parent 3aab66f1
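
In brief, `DQN.py` now takes a single `--env` argument and picks between tensorpack's own `AtariPlayer` and a gym Atari environment based on the file suffix. The sketch below condenses that selection logic from the diff; the `make_env` helper name is mine and not part of the commit.

```
import gym
from atari import AtariPlayer

def make_env(name, frame_skip=4):
    # A '.bin' rom goes through tensorpack's AtariPlayer (grayscale frames),
    # anything else is treated as a gym Atari environment id (RGB frames).
    use_gym = not name.endswith('.bin')
    image_channel = 3 if use_gym else 1
    env = gym.make(name) if use_gym else AtariPlayer(name, frame_skip=frame_skip)
    return env, image_channel

env, c = make_env('breakout.bin')                # ALE path, c == 1
env, c = make_env('BreakoutDeterministic-v4')    # gym path, c == 3
```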
@@ -8,18 +8,19 @@ import argparse
 import cv2
 import numpy as np
 import tensorflow as tf
+import gym
 
 from tensorpack import *
 
 from DQNModel import Model as DQNModel
 from common import Evaluator, eval_model_multithread, play_n_episodes
-from atari_wrapper import FrameStack, MapState, FireResetEnv
+from atari_wrapper import FrameStack, MapState, FireResetEnv, LimitLength
 from expreplay import ExpReplay
 from atari import AtariPlayer
 
 BATCH_SIZE = 64
 IMAGE_SIZE = (84, 84)
+IMAGE_CHANNEL = None   # 3 in gym and 1 in our own wrapper
 FRAME_HISTORY = 4
 ACTION_REPEAT = 4   # aka FRAME_SKIP
 UPDATE_FREQ = 4
@@ -33,24 +34,39 @@ STEPS_PER_EPOCH = 100000 // UPDATE_FREQ  # each epoch is 100k played frames
 EVAL_EPISODE = 50
 
 NUM_ACTIONS = None
-ROM_FILE = None
+USE_GYM = False
+ENV_NAME = None
 METHOD = None
 
 
+def resize_keepdims(im, size):
+    # Opencv's resize remove the extra dimension for grayscale images.
+    # We add it back.
+    ret = cv2.resize(im, size)
+    if im.ndim == 3 and ret.ndim == 2:
+        ret = ret[:, :, np.newaxis]
+    return ret
+
+
 def get_player(viz=False, train=False):
-    env = AtariPlayer(ROM_FILE, frame_skip=ACTION_REPEAT, viz=viz,
-                      live_lost_as_eoe=train, max_num_frames=60000)
+    if USE_GYM:
+        env = gym.make(ENV_NAME)
+    else:
+        env = AtariPlayer(ENV_NAME, frame_skip=ACTION_REPEAT, viz=viz,
+                          live_lost_as_eoe=train, max_num_frames=60000)
     env = FireResetEnv(env)
-    env = MapState(env, lambda im: cv2.resize(im, IMAGE_SIZE)[:, :, np.newaxis])
+    env = MapState(env, lambda im: resize_keepdims(im, IMAGE_SIZE))
     if not train:
         # in training, history is taken care of in expreplay buffer
         env = FrameStack(env, FRAME_HISTORY)
+    if train and USE_GYM:
+        env = LimitLength(env, 60000)
     return env
 
 
 class Model(DQNModel):
     def __init__(self):
-        super(Model, self).__init__(IMAGE_SIZE, 1, FRAME_HISTORY, METHOD, NUM_ACTIONS, GAMMA)
+        super(Model, self).__init__(IMAGE_SIZE, IMAGE_CHANNEL, FRAME_HISTORY, METHOD, NUM_ACTIONS, GAMMA)
 
     def _get_DQN_prediction(self, image):
         image = image / 255.0
@@ -86,7 +102,7 @@ def get_config():
     expreplay = ExpReplay(
         predictor_io_names=(['state'], ['Qvalue']),
         player=get_player(train=True),
-        state_shape=IMAGE_SIZE + (1,),
+        state_shape=IMAGE_SIZE + (IMAGE_CHANNEL,),
         batch_size=BATCH_SIZE,
         memory_size=MEMORY_SIZE,
         init_memory_size=INIT_MEMORY_SIZE,
@@ -126,18 +142,21 @@ if __name__ == '__main__':
     parser.add_argument('--load', help='load model')
     parser.add_argument('--task', help='task to perform',
                         choices=['play', 'eval', 'train'], default='train')
-    parser.add_argument('--rom', help='atari rom', required=True)
+    parser.add_argument('--env', required=True,
+                        help='either an atari rom file (that ends with .bin) or a gym atari environment name')
     parser.add_argument('--algo', help='algorithm',
                         choices=['DQN', 'Double', 'Dueling'], default='Double')
     args = parser.parse_args()
     if args.gpu:
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
-    ROM_FILE = args.rom
+    ENV_NAME = args.env
+    USE_GYM = not ENV_NAME.endswith('.bin')
+    IMAGE_CHANNEL = 3 if USE_GYM else 1
     METHOD = args.algo
 
     # set num_actions
-    NUM_ACTIONS = AtariPlayer(ROM_FILE).action_space.n
-    logger.info("ROM: {}, Num Actions: {}".format(ROM_FILE, NUM_ACTIONS))
+    NUM_ACTIONS = get_player().action_space.n
+    logger.info("ENV: {}, Num Actions: {}".format(ENV_NAME, NUM_ACTIONS))
 
     if args.task != 'train':
         assert args.load is not None
@@ -153,7 +172,7 @@ if __name__ == '__main__':
     else:
         logger.set_logger_dir(
             os.path.join('train_log', 'DQN-{}'.format(
-                os.path.basename(ROM_FILE).split('.')[0])))
+                os.path.basename(ENV_NAME).split('.')[0])))
         config = get_config()
         if args.load:
             config.session_init = get_model_loader(args.load)
......
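
Why the new `resize_keepdims` helper: `cv2.resize` silently drops the trailing singleton channel of a grayscale frame, while gym's RGB frames keep their three channels, so the helper restores the axis and observations stay shaped (h, w, IMAGE_CHANNEL) on both paths. A minimal sketch of that behaviour (the sample arrays are illustrative, not from the commit):

```
import cv2
import numpy as np

def resize_keepdims(im, size):
    # cv2.resize collapses an (h, w, 1) grayscale image to (h, w);
    # restore the channel axis so both backends yield (h, w, c).
    ret = cv2.resize(im, size)
    if im.ndim == 3 and ret.ndim == 2:
        ret = ret[:, :, np.newaxis]
    return ret

gray = np.zeros((210, 160, 1), np.uint8)   # AtariPlayer-style grayscale frame
rgb = np.zeros((210, 160, 3), np.uint8)    # gym-style RGB frame
print(resize_keepdims(gray, (84, 84)).shape)   # (84, 84, 1)
print(resize_keepdims(rgb, (84, 84)).shape)    # (84, 84, 3)
```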
@@ -26,6 +26,7 @@ Double-DQN with nature paper setting runs at 60 batches (3840 trained frames, 24
 ## How to use
 
+### With ALE (paper's setting):
 Install [ALE](https://github.com/mgbellemare/Arcade-Learning-Environment) and gym.
 Download an [atari rom](https://github.com/openai/atari-py/tree/master/atari_py/atari_roms), e.g.:
@@ -35,7 +36,7 @@ wget https://github.com/openai/atari-py/raw/master/atari_py/atari_roms/breakout.
 Start Training:
 ```
-./DQN.py --rom breakout.bin
+./DQN.py --env breakout.bin
 # use `--algo` to select other DQN algorithms. See `-h` for more options.
 ```
@@ -43,7 +44,15 @@ Watch the agent play:
 ```
 # Download pretrained models or use one you trained:
 wget http://models.tensorpack.com/DeepQNetwork/DoubleDQN-Breakout.npz
-./DQN.py --rom breakout.bin --task play --load DoubleDQN-Breakout.npz
+./DQN.py --env breakout.bin --task play --load DoubleDQN-Breakout.npz
+```
+
+### With gym's Atari:
+
+Install gym and atari_py.
+```
+./DQN.py --env BreakoutDeterministic-v4
 ```
 
 A3C code and models for Atari games in OpenAI Gym are released in [examples/A3C-Gym](../A3C-Gym)
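
For the new "With gym's Atari" path above, a quick way to sanity-check that gym and atari_py are installed is to build the raw environment and inspect its spaces; this snippet is an illustrative check, not part of the README:

```
import gym

env = gym.make('BreakoutDeterministic-v4')   # requires gym with atari_py
print(env.action_space)        # Discrete(4): NOOP, FIRE, RIGHT, LEFT
print(env.observation_space)   # raw (210, 160, 3) RGB frames, before any wrappers
```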
@@ -95,7 +95,7 @@ class AtariPlayer(gym.Env):
         self.action_space = spaces.Discrete(len(self.actions))
         self.observation_space = spaces.Box(
-            low=0, high=255, shape=(self.height, self.width), dtype=np.uint8)
+            low=0, high=255, shape=(self.height, self.width, 1), dtype=np.uint8)
         self._restart_episode()
 
     def get_action_meanings(self):
@@ -110,7 +110,7 @@ class AtariPlayer(gym.Env):
     def _current_state(self):
         """
-        :returns: a gray-scale (h, w) uint8 image
+        :returns: a gray-scale (h, w, 1) uint8 image
         """
         ret = self._grab_raw_image()
         # max-pooled over the last screen
@@ -121,7 +121,7 @@ class AtariPlayer(gym.Env):
             cv2.waitKey(int(self.viz * 1000))
         ret = ret.astype('float32')
         # 0.299,0.587.0.114. same as rgb2y in torch/image
-        ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
+        ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)[:, :, np.newaxis]
         return ret.astype('uint8')  # to save some memory
 
     def _restart_episode(self):
......
@@ -135,7 +135,7 @@ class ExpReplay(DataFlow, Callback):
         Args:
             predictor_io_names (tuple of list of str): input/output names to
                 predict Q value from state.
-            player (RLEnvironment): the player.
+            player (gym.Env): the player.
             state_shape (tuple): h, w, c
             history_len (int): length of history frames to concat. Zero-filled
                 initial frames.
......
@@ -30,7 +30,7 @@ def get_dorefa(bitW, bitA, bitG):
             @tf.custom_gradient
             def _sign(x):
-                return tf.sign(x / E) * E, lambda dy: dy
+                return tf.where(tf.equal(x, 0), tf.ones_like(x), tf.sign(x / E)) * E, lambda dy: dy
 
             return _sign(x)
......
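
The dorefa.py change addresses the weight binarizer at exactly zero: `tf.sign(0)` is 0, so a zero weight would be quantized to 0 rather than one of the two binary values, and the `tf.where` maps it to +E instead (the straight-through gradient `lambda dy: dy` is unchanged). A small NumPy illustration of the before/after mapping, not taken from the commit, with E standing in for the scaling factor computed in the surrounding code:

```
import numpy as np

x = np.array([-0.3, 0.0, 0.8], dtype=np.float32)
E = np.abs(x).mean()   # stand-in for the scaling factor E

old = np.sign(x / E) * E                                      # zero maps to 0.0, so output is not binary
new = np.where(x == 0, np.ones_like(x), np.sign(x / E)) * E   # zero maps to +E, output stays in {-E, +E}
print(old)   # approximately [-E, 0, +E]
print(new)   # approximately [-E, +E, +E]
```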