Commit 704bee73 authored by Yuxin Wu

move atari driver

parent 17687d5c
......@@ -2,3 +2,4 @@ mnist_data
cifar10_data
svhn_data
ilsvrc_metadata
bsds500_data
......@@ -2,19 +2,19 @@
# -*- coding: utf-8 -*-
# File: atari.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from ale_python_interface import ALEInterface
import numpy as np
import time
import os
import cv2
from .utils import get_rng
from collections import deque
from ...utils import get_rng
__all__ = ['AtariDriver']
__all__ = ['AtariDriver', 'AtariPlayer']
class AtariDriver(object):
"""
A driver for atari games.
A wrapper for atari emulator.
"""
def __init__(self, rom_file, frame_skip=1, viz=0):
"""
......@@ -25,7 +25,7 @@ class AtariDriver(object):
self.ale = ALEInterface()
self.rng = get_rng(self)
self.ale.setInt("random_seed", self.rng.randint(999))
self.ale.setInt("random_seed", self.rng.randint(self.rng.randint(0, 1000)))
self.ale.setInt("frame_skip", frame_skip)
self.ale.loadROM(rom_file)
self.width, self.height = self.ale.getScreenDims()
......@@ -42,7 +42,7 @@ class AtariDriver(object):
def _grab_raw_image(self):
"""
:returns: a 3-channel image
:returns: the current 3-channel image
"""
m = np.zeros(self.height * self.width * 3, dtype=np.uint8)
self.ale.getScreenRGB(m)
......@@ -50,7 +50,7 @@ class AtariDriver(object):
def grab_image(self):
"""
:returns: a gray-scale image, maximum over the last
:returns: a gray-scale image, max-pooled over the last frame.
"""
now = self._grab_raw_image()
ret = np.maximum(now, self.last_image)
......@@ -82,6 +82,68 @@ class AtariDriver(object):
self._reset()
return (s, r, isOver)
class AtariPlayer(object):
    """An Atari game player with limited memory and FPS.

    Wraps a driver, repeats every chosen action for several driver steps,
    and keeps the most recent `hist_len` resized frames as the state.
    """

    def __init__(self, driver, hist_len=4, action_repeat=4, image_shape=(84, 84)):
        """
        :param driver: an `AtariDriver` instance.
        :param hist_len: history(memory) length
        :param action_repeat: repeat each action `action_repeat` times and skip those frames
        :param image_shape: the shape (rows, columns) of the observed image
        """
        # Explicit assignments instead of iterating over locals(): the set of
        # attributes is visible at a glance and cannot pick up stray locals.
        self.driver = driver
        self.hist_len = hist_len
        self.action_repeat = action_repeat
        self.image_shape = image_shape
        self.last_act = 0
        # Bounded deque: appending a new frame silently drops the oldest one.
        self.frames = deque(maxlen=hist_len)
        self.restart()

    def restart(self):
        """
        Restart the game and populate frames with the beginning frame.

        Only the frame history is rebuilt here, from whatever image the
        driver currently shows; the driver itself is not reset by this method.
        """
        self.frames.clear()
        first = self._resize(self.driver.grab_image())
        for _ in range(self.hist_len):
            self.frames.append(first)

    def current_state(self):
        """
        Return a current state of shape `image_shape + (hist_len,)`
        """
        return self._build_state()

    def action(self, act):
        """
        Perform an action

        :param act: index of the action
        :returns: (new_frame, reward, isOver)
        """
        self.last_act = act
        return self._grab()

    def _resize(self, frame):
        """Resize `frame` so that its array shape equals `image_shape`."""
        # cv2.resize expects dsize as (width, height), i.e. the reverse of a
        # (rows, columns) array shape. Passing image_shape unreversed only
        # works for square shapes.
        return cv2.resize(frame, tuple(self.image_shape[::-1]))

    def _build_state(self):
        """Stack the stored frames along a new trailing axis."""
        assert len(self.frames) == self.hist_len
        stacked = np.array(self.frames)
        # (hist_len, rows, cols) -> (rows, cols, hist_len)
        return stacked.transpose([1, 2, 0])

    def _grab(self):
        """
        Repeat the last action up to `action_repeat` times and record the
        final frame.

        If isOver==True, current_state will return the new episode.

        :returns: (resized last frame, summed reward, isOver)
        """
        totr = 0
        for _ in range(self.action_repeat):
            s, r, isOver = self.driver.next(self.last_act)
            totr += r
            if isOver:
                # Stop repeating as soon as the episode ends.
                break
        s = self._resize(s)
        self.frames.append(s)
        if isOver:
            # Rebuild the history from the driver's current image
            # (presumably the first frame of the next episode -- confirm
            # against the driver's `next`).
            self.restart()
        return (s, totr, isOver)
if __name__ == '__main__':
a = AtariDriver('breakout.bin', viz=True)
num = a.get_num_actions()
......
......@@ -55,7 +55,6 @@ def logSoftmax(x):
logprob = z - tf.log(tf.reduce_sum(tf.exp(z), 1, keep_dims=True))
return logprob
def class_balanced_binary_class_cross_entropy(pred, label, name='cross_entropy_loss'):
"""
The class-balanced cross entropy loss for binary classification,
......@@ -80,3 +79,8 @@ def class_balanced_binary_class_cross_entropy(pred, label, name='cross_entropy_l
cost = tf.reduce_mean(cost, name=name)
return cost
def print_stat(x):
    """A simple print op.

    Use it like: x = print_stat(x)

    :param x: the tensor to inspect
    :returns: the result of `tf.Print` on `x`, set up to print the mean
        of `x` followed by `x` itself, with `summarize=20`.
    """
    mean_of_x = tf.reduce_mean(x)
    printed = [mean_of_x, x]
    return tf.Print(x, printed, summarize=20)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment