Commit 704bee73 authored by Yuxin Wu's avatar Yuxin Wu

move atari driver

parent 17687d5c
...@@ -2,3 +2,4 @@ mnist_data ...@@ -2,3 +2,4 @@ mnist_data
cifar10_data cifar10_data
svhn_data svhn_data
ilsvrc_metadata ilsvrc_metadata
bsds500_data
...@@ -2,19 +2,19 @@ ...@@ -2,19 +2,19 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: atari.py # File: atari.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from ale_python_interface import ALEInterface from ale_python_interface import ALEInterface
import numpy as np import numpy as np
import time import time
import os import os
import cv2 import cv2
from .utils import get_rng from collections import deque
from ...utils import get_rng
__all__ = ['AtariDriver'] __all__ = ['AtariDriver', 'AtariPlayer']
class AtariDriver(object): class AtariDriver(object):
""" """
A driver for atari games. A wrapper for atari emulator.
""" """
def __init__(self, rom_file, frame_skip=1, viz=0): def __init__(self, rom_file, frame_skip=1, viz=0):
""" """
...@@ -25,7 +25,7 @@ class AtariDriver(object): ...@@ -25,7 +25,7 @@ class AtariDriver(object):
self.ale = ALEInterface() self.ale = ALEInterface()
self.rng = get_rng(self) self.rng = get_rng(self)
self.ale.setInt("random_seed", self.rng.randint(999)) self.ale.setInt("random_seed", self.rng.randint(self.rng.randint(0, 1000)))
self.ale.setInt("frame_skip", frame_skip) self.ale.setInt("frame_skip", frame_skip)
self.ale.loadROM(rom_file) self.ale.loadROM(rom_file)
self.width, self.height = self.ale.getScreenDims() self.width, self.height = self.ale.getScreenDims()
...@@ -42,7 +42,7 @@ class AtariDriver(object): ...@@ -42,7 +42,7 @@ class AtariDriver(object):
def _grab_raw_image(self): def _grab_raw_image(self):
""" """
:returns: a 3-channel image :returns: the current 3-channel image
""" """
m = np.zeros(self.height * self.width * 3, dtype=np.uint8) m = np.zeros(self.height * self.width * 3, dtype=np.uint8)
self.ale.getScreenRGB(m) self.ale.getScreenRGB(m)
...@@ -50,7 +50,7 @@ class AtariDriver(object): ...@@ -50,7 +50,7 @@ class AtariDriver(object):
def grab_image(self): def grab_image(self):
""" """
:returns: a gray-scale image, maximum over the last :returns: a gray-scale image, max-pooled over the last frame.
""" """
now = self._grab_raw_image() now = self._grab_raw_image()
ret = np.maximum(now, self.last_image) ret = np.maximum(now, self.last_image)
...@@ -82,6 +82,68 @@ class AtariDriver(object): ...@@ -82,6 +82,68 @@ class AtariDriver(object):
self._reset() self._reset()
return (s, r, isOver) return (s, r, isOver)
class AtariPlayer(object):
""" An Atari game player with limited memory and FPS"""
def __init__(self, driver, hist_len=4, action_repeat=4, image_shape=(84,84)):
"""
:param driver: an `AtariDriver` instance.
:param hist_len: history(memory) length
:param action_repeat: repeat each action `action_repeat` times and skip those frames
:param image_shape: the shape of the observed image
"""
for k, v in locals().items():
if k != 'self':
setattr(self, k, v)
self.last_act = 0
self.frames = deque(maxlen=hist_len)
self.restart()
def restart(self):
"""
Restart the game and populate frames with the beginning frame
"""
self.frames.clear()
s = self.driver.grab_image()
s = cv2.resize(s, self.image_shape)
for _ in range(self.hist_len):
self.frames.append(s)
def current_state(self):
"""
Return a current state of shape `image_shape + (hist_len,)`
"""
return self._build_state()
def action(self, act):
"""
Perform an action
:param act: index of the action
:returns: (new_frame, reward, isOver)
"""
self.last_act = act
return self._grab()
def _build_state(self):
assert len(self.frames) == self.hist_len
m = np.array(self.frames)
m = m.transpose([1,2,0])
return m
def _grab(self):
""" if isOver==True, current_state will return the new episode
"""
totr = 0
for k in range(self.action_repeat):
s, r, isOver = self.driver.next(self.last_act)
totr += r
if isOver:
break
s = cv2.resize(s, self.image_shape)
self.frames.append(s)
if isOver:
self.restart()
return (s, totr, isOver)
if __name__ == '__main__': if __name__ == '__main__':
a = AtariDriver('breakout.bin', viz=True) a = AtariDriver('breakout.bin', viz=True)
num = a.get_num_actions() num = a.get_num_actions()
......
...@@ -55,7 +55,6 @@ def logSoftmax(x): ...@@ -55,7 +55,6 @@ def logSoftmax(x):
logprob = z - tf.log(tf.reduce_sum(tf.exp(z), 1, keep_dims=True)) logprob = z - tf.log(tf.reduce_sum(tf.exp(z), 1, keep_dims=True))
return logprob return logprob
def class_balanced_binary_class_cross_entropy(pred, label, name='cross_entropy_loss'): def class_balanced_binary_class_cross_entropy(pred, label, name='cross_entropy_loss'):
""" """
The class-balanced cross entropy loss for binary classification, The class-balanced cross entropy loss for binary classification,
...@@ -80,3 +79,8 @@ def class_balanced_binary_class_cross_entropy(pred, label, name='cross_entropy_l ...@@ -80,3 +79,8 @@ def class_balanced_binary_class_cross_entropy(pred, label, name='cross_entropy_l
cost = tf.reduce_mean(cost, name=name) cost = tf.reduce_mean(cost, name=name)
return cost return cost
def print_stat(x):
""" a simple print op.
Use it like: x = print_stat(x)
"""
return tf.Print(x, [tf.reduce_mean(x), x], summarize=20)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment