Commit ad5321a6 authored by Yuxin Wu's avatar Yuxin Wu

[RL] use more general MapState instead of WarpFrame

parent 2c429763
...@@ -32,7 +32,7 @@ import gym ...@@ -32,7 +32,7 @@ import gym
from simulator import * from simulator import *
from common import (Evaluator, eval_model_multithread, from common import (Evaluator, eval_model_multithread,
play_one_episode, play_n_episodes) play_one_episode, play_n_episodes)
from atari_wrapper import WarpFrame, FrameStack, FireResetEnv, LimitLength from atari_wrapper import MapState, FrameStack, FireResetEnv, LimitLength
if six.PY3: if six.PY3:
from concurrent import futures from concurrent import futures
...@@ -64,7 +64,7 @@ def get_player(train=False, dumpdir=None): ...@@ -64,7 +64,7 @@ def get_player(train=False, dumpdir=None):
if dumpdir: if dumpdir:
env = gym.wrappers.Monitor(env, dumpdir) env = gym.wrappers.Monitor(env, dumpdir)
env = FireResetEnv(env) env = FireResetEnv(env)
env = WarpFrame(env, IMAGE_SIZE) env = MapState(env, lambda im: cv2.resize(im, IMAGE_SIZE))
env = FrameStack(env, 4) env = FrameStack(env, 4)
if train: if train:
env = LimitLength(env, 60000) env = LimitLength(env, 60000)
......
...@@ -15,6 +15,7 @@ import subprocess ...@@ -15,6 +15,7 @@ import subprocess
import multiprocessing import multiprocessing
import threading import threading
from collections import deque from collections import deque
import cv2
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
...@@ -23,7 +24,7 @@ import tensorflow as tf ...@@ -23,7 +24,7 @@ import tensorflow as tf
from DQNModel import Model as DQNModel from DQNModel import Model as DQNModel
from common import Evaluator, eval_model_multithread, play_n_episodes from common import Evaluator, eval_model_multithread, play_n_episodes
from atari_wrapper import FrameStack, WarpFrame, FireResetEnv from atari_wrapper import FrameStack, MapState, FireResetEnv
from expreplay import ExpReplay from expreplay import ExpReplay
from atari import AtariPlayer from atari import AtariPlayer
...@@ -50,7 +51,7 @@ def get_player(viz=False, train=False): ...@@ -50,7 +51,7 @@ def get_player(viz=False, train=False):
env = AtariPlayer(ROM_FILE, frame_skip=ACTION_REPEAT, viz=viz, env = AtariPlayer(ROM_FILE, frame_skip=ACTION_REPEAT, viz=viz,
live_lost_as_eoe=train, max_num_frames=30000) live_lost_as_eoe=train, max_num_frames=30000)
env = FireResetEnv(env) env = FireResetEnv(env)
env = WarpFrame(env, IMAGE_SIZE) env = MapState(env, lambda im: cv2.resize(im, IMAGE_SIZE))
if not train: if not train:
# in training, history is taken care of in expreplay buffer # in training, history is taken care of in expreplay buffer
env = FrameStack(env, FRAME_HISTORY) env = FrameStack(env, FRAME_HISTORY)
......
...@@ -16,18 +16,13 @@ https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers. ...@@ -16,18 +16,13 @@ https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.
""" """
class MapState(gym.ObservationWrapper):
    """Generic observation wrapper that applies ``map_func`` to every observation.

    Generalizes the old ``WarpFrame`` (which only resized frames): any
    callable taking one observation and returning the transformed
    observation can be used, e.g. ``lambda im: cv2.resize(im, IMAGE_SIZE)``.

    Note: unlike ``WarpFrame``, this wrapper does not update
    ``self.observation_space`` to reflect the mapped shape — callers that
    inspect ``observation_space`` see the inner env's space. TODO confirm
    downstream code does not rely on it.
    """

    def __init__(self, env, map_func):
        """
        Args:
            env: the wrapped gym environment.
            map_func: callable ``obs -> obs`` applied to each observation.
        """
        super().__init__(env)
        self._func = map_func

    def _observation(self, obs):
        # Old-style gym ObservationWrapper hook (pre-0.9.6 underscore API):
        # called on every observation returned by reset()/step().
        return self._func(obs)
class FrameStack(gym.Wrapper): class FrameStack(gym.Wrapper):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment