Commit ad5321a6 authored by Yuxin Wu

[RL] use more general MapState instead of WarpFrame

parent 2c429763
......@@ -32,7 +32,7 @@ import gym
from simulator import *
from common import (Evaluator, eval_model_multithread,
play_one_episode, play_n_episodes)
from atari_wrapper import WarpFrame, FrameStack, FireResetEnv, LimitLength
from atari_wrapper import MapState, FrameStack, FireResetEnv, LimitLength
if six.PY3:
from concurrent import futures
......@@ -64,7 +64,7 @@ def get_player(train=False, dumpdir=None):
if dumpdir:
env = gym.wrappers.Monitor(env, dumpdir)
env = FireResetEnv(env)
env = WarpFrame(env, IMAGE_SIZE)
env = MapState(env, lambda im: cv2.resize(im, IMAGE_SIZE))
env = FrameStack(env, 4)
if train:
env = LimitLength(env, 60000)
......
......@@ -15,6 +15,7 @@ import subprocess
import multiprocessing
import threading
from collections import deque
import cv2
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import *
......@@ -23,7 +24,7 @@ import tensorflow as tf
from DQNModel import Model as DQNModel
from common import Evaluator, eval_model_multithread, play_n_episodes
from atari_wrapper import FrameStack, WarpFrame, FireResetEnv
from atari_wrapper import FrameStack, MapState, FireResetEnv
from expreplay import ExpReplay
from atari import AtariPlayer
......@@ -50,7 +51,7 @@ def get_player(viz=False, train=False):
env = AtariPlayer(ROM_FILE, frame_skip=ACTION_REPEAT, viz=viz,
live_lost_as_eoe=train, max_num_frames=30000)
env = FireResetEnv(env)
env = WarpFrame(env, IMAGE_SIZE)
env = MapState(env, lambda im: cv2.resize(im, IMAGE_SIZE))
if not train:
# in training, history is taken care of in expreplay buffer
env = FrameStack(env, FRAME_HISTORY)
......
......@@ -16,18 +16,13 @@ https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.
"""
class WarpFrame(gym.ObservationWrapper):
def __init__(self, env, shape):
class MapState(gym.ObservationWrapper):
def __init__(self, env, map_func):
gym.ObservationWrapper.__init__(self, env)
self.shape = shape
obs = env.observation_space
assert isinstance(obs, spaces.Box)
chan = 1 if len(obs.shape) == 2 else obs.shape[2]
shape3d = shape if chan == 1 else shape + (chan,)
self.observation_space = spaces.Box(low=0, high=255, shape=shape3d)
self._func = map_func
def _observation(self, obs):
return cv2.resize(obs, self.shape)
return self._func(obs)
class FrameStack(gym.Wrapper):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment