Commit 48115f68 authored by Yuxin Wu

move rlenv to RL.py

parent 90a14aa4
......
@@ -3,36 +3,59 @@
 # File: RL.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
-from .base import DataFlow
-from tensorpack.utils import *
-from tensorpack.callbacks.base import Callback
-from tqdm import tqdm
+from abc import abstractmethod, ABCMeta
 import random
 import numpy as np
-import cv2
 from collections import deque, namedtuple
+from tqdm import tqdm
+import cv2
+
+from .base import DataFlow
+from tensorpack.utils import *
+from tensorpack.callbacks.base import Callback

 """
 Implement RL-related data preprocessing
 """

-__all__ = ['ExpReplay']
+__all__ = ['ExpReplay', 'RLEnvironment', 'NaiveRLEnvironment']

 Experience = namedtuple('Experience',
     ['state', 'action', 'reward', 'next', 'isOver'])

-def view_state(state):
-    # for debug
-    r = np.concatenate([state[:,:,k] for k in range(state.shape[2])], axis=1)
-    print r.shape
-    cv2.imshow("state", r)
-    cv2.waitKey()
+class RLEnvironment(object):
+    __metaclass__ = ABCMeta
+
+    @abstractmethod
+    def current_state(self):
+        """
+        Observe, return a state representation
+        """
+
+    @abstractmethod
+    def action(self, act):
+        """
+        Perform an action
+        :param act: the action
+        :returns: (reward, isOver)
+        """
+
+class NaiveRLEnvironment(RLEnvironment):
+    """ for testing only """
+    def __init__(self):
+        self.k = 0
+
+    def current_state(self):
+        self.k += 1
+        return self.k
+
+    def action(self, act):
+        self.k = act
+        return (self.k, self.k > 10)
 class ExpReplay(DataFlow, Callback):
     """
-    Implement experience replay.
+    Implement experience replay in the paper
+    `Human-level control through deep reinforcement learning`.
     """
     def __init__(self,
                  predictor,
......
@@ -78,6 +101,12 @@ class ExpReplay(DataFlow, Callback):
         reward = np.clip(reward, self.reward_clip[0], self.reward_clip[1])
         s = self.player.current_state()

+        #def view_state(state):
+        #    """ for debug state representation"""
+        #    r = np.concatenate([state[:,:,k] for k in range(state.shape[2])], axis=1)
+        #    print r.shape
+        #    cv2.imshow("state", r)
+        #    cv2.waitKey()
         #print act, reward
         #view_state(s)
......
@@ -116,6 +145,8 @@ class ExpReplay(DataFlow, Callback):
             self.exploration -= self.exploration_epoch_anneal
             logger.info("Exploration changed to {}".format(self.exploration))

 if __name__ == '__main__':
     from tensorpack.dataflow.dataset import AtariDriver, AtariPlayer
     predictor = lambda x: np.array([1,1,1,1])
......
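The two classes added above define the environment protocol used by ExpReplay: current_state() returns an observation, and action(act) applies an action and returns a (reward, isOver) pair. A minimal interaction loop, sketched here against the toy NaiveRLEnvironment from this commit (the total_reward accumulator is illustrative, not part of the API):

    # Sketch: driving an RLEnvironment by hand. NaiveRLEnvironment's
    # reward is its own counter, and the episode ends once it exceeds 10.
    env = NaiveRLEnvironment()
    total_reward = 0
    isOver = False
    while not isOver:
        s = env.current_state()             # observe a state
        reward, isOver = env.action(s + 1)  # act; get (reward, isOver)
        total_reward += reward

AtariPlayer in the next file implements the same contract on top of the ALE emulator.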
......
@@ -9,7 +9,7 @@ import os
 import cv2
 from collections import deque
 from ...utils import get_rng, logger
-from .rlenv import RLEnvironment
+from ..RL import RLEnvironment
 try:
     from ale_python_interface import ALEInterface
......
-#!/usr/bin/env python2
-# -*- coding: utf-8 -*-
-# File: rlenv.py
-# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
-
-from abc import abstractmethod, ABCMeta
-
-__all__ = ['RLEnvironment', 'NaiveRLEnvironment']
-
-class RLEnvironment(object):
-    __metaclass__ = ABCMeta
-
-    @abstractmethod
-    def current_state(self):
-        """
-        Observe, return a state representation
-        """
-
-    @abstractmethod
-    def action(self, act):
-        """
-        Perform an action
-        :param act: the action
-        :returns: (reward, isOver)
-        """
-
-class NaiveRLEnvironment(RLEnvironment):
-    def __init__(self):
-        self.k = 0
-
-    def current_state(self):
-        self.k += 1
-        return self.k
-
-    def action(self, act):
-        self.k = act
-        return (self.k, self.k > 10)
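A note on the __metaclass__ line in RLEnvironment: in Python 2, abstract-method enforcement requires the magic attribute name __metaclass__; a class attribute with any other spelling is ignored and @abstractmethod has no effect. With the metaclass in place, instantiating a subclass that omits an abstract method fails loudly. A minimal Python 2 sketch (Env and Incomplete are hypothetical names, not part of this commit):

    from abc import abstractmethod, ABCMeta

    class Env(object):          # hypothetical stand-in for RLEnvironment
        __metaclass__ = ABCMeta

        @abstractmethod
        def action(self, act):
            pass

    class Incomplete(Env):      # forgets to implement action()
        pass

    Incomplete()  # raises TypeError: Can't instantiate abstract class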
......
@@ -22,6 +22,7 @@ def get_default_sess_config(mem_fraction=0.9):
     conf = tf.ConfigProto()
     conf.gpu_options.per_process_gpu_memory_fraction = mem_fraction
     conf.gpu_options.allocator_type = 'BFC'
+    conf.gpu_options.allow_growth = True
     conf.allow_soft_placement = True
     #conf.log_device_placement = True
     return conf
......
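For context, get_default_sess_config builds the tf.ConfigProto that tensorpack sessions start with; allow_growth makes TensorFlow allocate GPU memory on demand instead of reserving the full mem_fraction up front. A minimal usage sketch (assuming the function is importable from tensorpack.utils, as the wildcard import in RL.py above suggests):

    import tensorflow as tf
    from tensorpack.utils import get_default_sess_config

    # Cap this process at ~50% of GPU memory, allocated lazily
    # thanks to allow_growth.
    sess = tf.Session(config=get_default_sess_config(0.5))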