Commit de6d5502 authored by Yuxin Wu

rl environment base

parent 97dd6c5c
......@@ -168,7 +168,7 @@ def play_model(model_path, romfile):
act = 1
que.append(act)
print(act)
_, reward, isOver = player.action(act)
reward, isOver = player.action(act)
tot_reward += reward
if isOver:
print("Total:", tot_reward)
......@@ -210,7 +210,7 @@ def eval_model_multiprocess(model_path, romfile):
act = 1
que.append(act)
#print(act)
_, reward, isOver = player.action(act)
reward, isOver = player.action(act)
tot_reward += reward
if isOver:
self.outq.put(tot_reward)
......
......@@ -61,7 +61,7 @@ class AtariExpReplay(DataFlow):
act = self.rng.choice(range(self.num_actions))
else:
act = np.argmax(self.predictor(old_s)) # TODO race condition in session?
_, reward, isOver = self.player.action(act)
reward, isOver = self.player.action(act)
reward = np.clip(reward, -1, 2)
s = self.player.current_state()
......
......@@ -9,6 +9,7 @@ import os
import cv2
from collections import deque
from ...utils import get_rng
from . import RLEnvironment
__all__ = ['AtariDriver', 'AtariPlayer']
......@@ -86,7 +87,7 @@ class AtariDriver(object):
self._reset()
return (s, r, isOver)
class AtariPlayer(object):
class AtariPlayer(RLEnvironment):
""" An Atari game player with limited memory and FPS"""
def __init__(self, driver, hist_len=4, action_repeat=4, image_shape=(84,84)):
"""
......@@ -125,7 +126,7 @@ class AtariPlayer(object):
:returns: (new_frame, reward, isOver)
"""
self.last_act = act
return self._grab()
return self._observe()
def _build_state(self):
assert len(self.frames) == self.hist_len
......@@ -133,7 +134,7 @@ class AtariPlayer(object):
m = m.transpose([1,2,0])
return m
def _grab(self):
def _observe(self):
""" if isOver==True, current_state will return the new episode
"""
totr = 0
......@@ -146,7 +147,7 @@ class AtariPlayer(object):
self.frames.append(s)
if isOver:
self.restart()
return (s, totr, isOver)
return (totr, isOver)
if __name__ == '__main__':
a = AtariDriver('breakout.bin', viz=True)
......
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
# File: rlenv.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from abc import abstractmethod, ABCMeta
__all__ = ['RLEnvironment']
# NOTE: the metaclass is applied through an intermediate base created by
# calling ABCMeta directly. This form works on both Python 2 and Python 3;
# a class attribute such as ``__meta__`` is ignored by the interpreter and
# leaves @abstractmethod unenforced.
class RLEnvironment(ABCMeta('_RLEnvironmentBase', (object,), {})):
    """
    Abstract base class for a reinforcement-learning environment.

    Subclasses must implement both :meth:`current_state` and :meth:`action`;
    instantiating a class that leaves either one undefined raises TypeError.
    """

    @abstractmethod
    def current_state(self):
        """
        Observe, return a state representation
        """

    @abstractmethod
    def action(self, act):
        """
        Perform an action
        :param act: the action
        :returns: (reward, isOver)
        """
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment