Commit d95bf7a1 authored by Yuxin Wu

[A3C] support ALE in A3C

parent ea525564
......@@ -46,13 +46,20 @@ ENV_NAME = None
def get_player(train=False, dumpdir=None):
    """Build the wrapped game environment selected by the global ``ENV_NAME``.

    If ``ENV_NAME`` ends with ``".bin"`` it is treated as a ROM file and
    played through ``AtariPlayer``; otherwise it is passed to ``gym.make``.

    Args:
        train: when True, ``AtariPlayer`` is told to treat a lost life as the
            end of an episode, and gym environments are additionally wrapped
            in ``LimitLength``.
        dumpdir: if truthy, wrap the environment in ``gym.wrappers.Monitor``
            writing to this directory, recording a video for every episode.

    Returns:
        The environment after the FireReset, resize and 4-frame stacking
        wrappers have been applied.
    """
    use_gym = not ENV_NAME.endswith(".bin")
    if use_gym:
        env = gym.make(ENV_NAME)
    else:
        # Imported lazily so that the ROM-based player is only required
        # when a ".bin" ROM is actually requested.
        from atari import AtariPlayer
        env = AtariPlayer(ENV_NAME, frame_skip=4, viz=False,
                          live_lost_as_eoe=train, max_num_frames=60000,
                          grayscale=False)
    if dumpdir:
        env = gym.wrappers.Monitor(env, dumpdir, video_callable=lambda _: True)
    env = FireResetEnv(env)
    env = MapState(env, lambda im: cv2.resize(im, IMAGE_SIZE))
    env = FrameStack(env, 4)
    # AtariPlayer is already constructed with max_num_frames=60000, so the
    # explicit length limit is only added for gym environments.
    if train and use_gym:
        env = LimitLength(env, 60000)
    return env
......@@ -75,8 +82,9 @@ class Model(ModelDesc):
assert state.shape.rank == 5 # Batch, H, W, Channel, History
state = tf.transpose(state, [0, 1, 2, 4, 3]) # swap channel & history, to be compatible with old models
image = tf.reshape(state, [-1] + list(STATE_SHAPE[:2]) + [STATE_SHAPE[2] * FRAME_HISTORY])
image = tf.cast(image, tf.float32)
image = tf.cast(image, tf.float32) / 255.0
image = image / 255.0
with argscope(Conv2D, activation=tf.nn.relu):
l = Conv2D('conv0', image, 32, 5)
l = MaxPooling('pool0', l, 2)
......
......@@ -32,7 +32,8 @@ class AtariPlayer(gym.Env):
def __init__(self, rom_file, viz=0,
frame_skip=4, nullop_start=30,
live_lost_as_eoe=True, max_num_frames=0):
live_lost_as_eoe=True, max_num_frames=0,
grayscale=True):
"""
Args:
rom_file: path to the rom
......@@ -44,6 +45,7 @@ class AtariPlayer(gym.Env):
nullop_start: start with random number of null ops.
live_lost_as_eoe: consider the loss of a life as the end of an episode. Useful for training.
max_num_frames: maximum number of frames per episode.
grayscale (bool): if True, return 2D image. Otherwise return HWC image.
"""
super(AtariPlayer, self).__init__()
if not os.path.isfile(rom_file) and '/' not in rom_file:
......@@ -91,8 +93,10 @@ class AtariPlayer(gym.Env):
self.nullop_start = nullop_start
self.action_space = spaces.Discrete(len(self.actions))
self.grayscale = grayscale
shape = (self.height, self.width) if grayscale else (self.height, self.width, 3)
self.observation_space = spaces.Box(
low=0, high=255, shape=(self.height, self.width), dtype=np.uint8)
low=0, high=255, shape=shape, dtype=np.uint8)
self._restart_episode()
def get_action_meanings(self):
......@@ -116,9 +120,9 @@ class AtariPlayer(gym.Env):
if isinstance(self.viz, float):
cv2.imshow(self.windowname, ret)
cv2.waitKey(int(self.viz * 1000))
ret = ret.astype('float32')
if self.grayscale:
# 0.299, 0.587, 0.114. same as rgb2y in torch/image
ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)[:, :]
ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
return ret.astype('uint8') # to save some memory
def _restart_episode(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment