Commit de6d5502 authored by Yuxin Wu's avatar Yuxin Wu

rl environment base

parent 97dd6c5c
...@@ -168,7 +168,7 @@ def play_model(model_path, romfile): ...@@ -168,7 +168,7 @@ def play_model(model_path, romfile):
act = 1 act = 1
que.append(act) que.append(act)
print(act) print(act)
_, reward, isOver = player.action(act) reward, isOver = player.action(act)
tot_reward += reward tot_reward += reward
if isOver: if isOver:
print("Total:", tot_reward) print("Total:", tot_reward)
...@@ -210,7 +210,7 @@ def eval_model_multiprocess(model_path, romfile): ...@@ -210,7 +210,7 @@ def eval_model_multiprocess(model_path, romfile):
act = 1 act = 1
que.append(act) que.append(act)
#print(act) #print(act)
_, reward, isOver = player.action(act) reward, isOver = player.action(act)
tot_reward += reward tot_reward += reward
if isOver: if isOver:
self.outq.put(tot_reward) self.outq.put(tot_reward)
......
...@@ -61,7 +61,7 @@ class AtariExpReplay(DataFlow): ...@@ -61,7 +61,7 @@ class AtariExpReplay(DataFlow):
act = self.rng.choice(range(self.num_actions)) act = self.rng.choice(range(self.num_actions))
else: else:
act = np.argmax(self.predictor(old_s)) # TODO race condition in session? act = np.argmax(self.predictor(old_s)) # TODO race condition in session?
_, reward, isOver = self.player.action(act) reward, isOver = self.player.action(act)
reward = np.clip(reward, -1, 2) reward = np.clip(reward, -1, 2)
s = self.player.current_state() s = self.player.current_state()
......
...@@ -9,6 +9,7 @@ import os ...@@ -9,6 +9,7 @@ import os
import cv2 import cv2
from collections import deque from collections import deque
from ...utils import get_rng from ...utils import get_rng
from . import RLEnvironment
__all__ = ['AtariDriver', 'AtariPlayer'] __all__ = ['AtariDriver', 'AtariPlayer']
...@@ -86,7 +87,7 @@ class AtariDriver(object): ...@@ -86,7 +87,7 @@ class AtariDriver(object):
self._reset() self._reset()
return (s, r, isOver) return (s, r, isOver)
class AtariPlayer(object): class AtariPlayer(RLEnvironment):
""" An Atari game player with limited memory and FPS""" """ An Atari game player with limited memory and FPS"""
def __init__(self, driver, hist_len=4, action_repeat=4, image_shape=(84,84)): def __init__(self, driver, hist_len=4, action_repeat=4, image_shape=(84,84)):
""" """
...@@ -125,7 +126,7 @@ class AtariPlayer(object): ...@@ -125,7 +126,7 @@ class AtariPlayer(object):
:returns: (new_frame, reward, isOver) :returns: (new_frame, reward, isOver)
""" """
self.last_act = act self.last_act = act
return self._grab() return self._observe()
def _build_state(self): def _build_state(self):
assert len(self.frames) == self.hist_len assert len(self.frames) == self.hist_len
...@@ -133,7 +134,7 @@ class AtariPlayer(object): ...@@ -133,7 +134,7 @@ class AtariPlayer(object):
m = m.transpose([1,2,0]) m = m.transpose([1,2,0])
return m return m
def _grab(self): def _observe(self):
""" if isOver==True, current_state will return the new episode """ if isOver==True, current_state will return the new episode
""" """
totr = 0 totr = 0
...@@ -146,7 +147,7 @@ class AtariPlayer(object): ...@@ -146,7 +147,7 @@ class AtariPlayer(object):
self.frames.append(s) self.frames.append(s)
if isOver: if isOver:
self.restart() self.restart()
return (s, totr, isOver) return (totr, isOver)
if __name__ == '__main__': if __name__ == '__main__':
a = AtariDriver('breakout.bin', viz=True) a = AtariDriver('breakout.bin', viz=True)
......
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
# File: rlenv.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from abc import abstractmethod, ABCMeta
__all__ = ['RLEnvironment']
class RLEnvironment(object):
    """Abstract base class for a reinforcement-learning environment.

    Subclasses (e.g. an Atari game player) must implement
    :meth:`current_state` and :meth:`action`.
    """
    # FIX: the original wrote ``__meta__ = ABCMeta`` -- an inert, unrecognized
    # attribute -- so abstract-method enforcement was silently disabled.
    # Python 2 assigns a metaclass via ``__metaclass__``.
    __metaclass__ = ABCMeta

    @abstractmethod
    def current_state(self):
        """
        Observe the environment.

        :returns: a state representation of the current observation
        """

    @abstractmethod
    def action(self, act):
        """
        Perform an action in the environment.

        :param act: the action to perform
        :returns: (reward, isOver)
        """
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment