Commit 48115f68 authored by Yuxin Wu

move rlenv to RL.py

parent 90a14aa4
@@ -3,36 +3,59 @@
 # File: RL.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
-from .base import DataFlow
-from tensorpack.utils import *
-from tensorpack.callbacks.base import Callback
-from tqdm import tqdm
+from abc import abstractmethod, ABCMeta
 import random
 import numpy as np
-import cv2
 from collections import deque, namedtuple
+from tqdm import tqdm
+import cv2
+
+from .base import DataFlow
+from tensorpack.utils import *
+from tensorpack.callbacks.base import Callback

 """
 Implement RL-related data preprocessing
 """

-__all__ = ['ExpReplay']
+__all__ = ['ExpReplay', 'RLEnvironment', 'NaiveRLEnvironment']

 Experience = namedtuple('Experience',
     ['state', 'action', 'reward', 'next', 'isOver'])

-def view_state(state):
-    # for debug
-    r = np.concatenate([state[:,:,k] for k in range(state.shape[2])], axis=1)
-    print r.shape
-    cv2.imshow("state", r)
-    cv2.waitKey()
+class RLEnvironment(object):
+    __meta__ = ABCMeta
+
+    @abstractmethod
+    def current_state(self):
+        """
+        Observe, return a state representation
+        """
+
+    @abstractmethod
+    def action(self, act):
+        """
+        Perform an action
+        :params act: the action
+        :returns: (reward, isOver)
+        """
+
+class NaiveRLEnvironment(RLEnvironment):
+    """ for testing only"""
+    def __init__(self):
+        self.k = 0
+    def current_state(self):
+        self.k += 1
+        return self.k
+    def action(self, act):
+        self.k = act
+        return (self.k, self.k > 10)

 class ExpReplay(DataFlow, Callback):
     """
-    Implement experience replay.
+    Implement experience replay in the paper
+    `Human-level control through deep reinforcement learning`.
     """
     def __init__(self,
             predictor,
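For context, RLEnvironment (added above) is the interface ExpReplay drives: current_state() returns an observation, and action(act) returns a (reward, isOver) pair. Below is a minimal sketch of a custom environment plus a driver loop; GuessEnv and run_episode are illustrative names that are not part of this commit, and the import path is inferred from the other imports in the diff.

# Illustrative only: GuessEnv and run_episode are hypothetical helpers that
# exercise the RLEnvironment interface; they are not part of this commit.
import random

from tensorpack.dataflow.RL import RLEnvironment   # path inferred from the imports above


class GuessEnv(RLEnvironment):
    """Toy environment: reward 1 when the action matches a hidden target."""
    def __init__(self):
        self.target = random.randint(0, 3)
        self.steps = 0

    def current_state(self):
        # the observation is just how many actions have been taken so far
        return self.steps

    def action(self, act):
        self.steps += 1
        reward = 1 if act == self.target else 0
        return (reward, self.steps >= 5)   # episode ends after 5 actions


def run_episode(env, policy):
    """Drive any RLEnvironment with a policy mapping state -> action."""
    total = 0
    while True:
        reward, is_over = env.action(policy(env.current_state()))
        total += reward
        if is_over:
            return total


print(run_episode(GuessEnv(), lambda s: random.randint(0, 3)))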
@@ -78,6 +101,12 @@ class ExpReplay(DataFlow, Callback):
         reward = np.clip(reward, self.reward_clip[0], self.reward_clip[1])
         s = self.player.current_state()

+        #def view_state(state):
+            #""" for debug state representation"""
+            #r = np.concatenate([state[:,:,k] for k in range(state.shape[2])], axis=1)
+            #print r.shape
+            #cv2.imshow("state", r)
+            #cv2.waitKey()
         #print act, reward
         #view_state(s)
@@ -116,6 +145,8 @@ class ExpReplay(DataFlow, Callback):
             self.exploration -= self.exploration_epoch_anneal
             logger.info("Exploration changed to {}".format(self.exploration))

 if __name__ == '__main__':
     from tensorpack.dataflow.dataset import AtariDriver, AtariPlayer
     predictor = lambda x: np.array([1,1,1,1])
......
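Most of the ExpReplay body is collapsed in this diff. What it implements is the replay memory from the DQN paper cited in the new docstring: transitions are stored as the Experience namedtuple and sampled for training. The following is a rough, self-contained sketch of that idea only; SimpleReplayBuffer is illustrative and much simpler than the real ExpReplay, which also drives the player and anneals exploration.

# Minimal experience-replay sketch, reusing the Experience namedtuple from
# RL.py above. SimpleReplayBuffer is illustrative only, not the class in this commit.
import random
from collections import deque, namedtuple

Experience = namedtuple('Experience',
    ['state', 'action', 'reward', 'next', 'isOver'])

class SimpleReplayBuffer(object):
    def __init__(self, capacity):
        self.mem = deque(maxlen=capacity)   # oldest transitions fall off when full

    def append(self, exp):
        self.mem.append(exp)

    def sample(self, batch_size):
        # uniform sampling over the stored transitions
        return random.sample(self.mem, batch_size)

buf = SimpleReplayBuffer(1000)
buf.append(Experience(state=0, action=1, reward=0.0, next=1, isOver=False))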
@@ -9,7 +9,7 @@ import os
 import cv2
 from collections import deque
 from ...utils import get_rng, logger
-from .rlenv import RLEnvironment
+from ..RL import RLEnvironment
 try:
     from ale_python_interface import ALEInterface
......
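The one-line change above follows the file move: rlenv.py sat in the same dataset subpackage as this Atari module, while RL.py lives one level up in the dataflow package, so the relative import changes from .rlenv to ..RL. In absolute form, assuming the package layout implied by the commit's other imports (the path is inferred, not shown explicitly in the diff):

# equivalent absolute import, assuming RL.py is tensorpack/dataflow/RL.py
from tensorpack.dataflow.RL import RLEnvironment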
-#!/usr/bin/env python2
-# -*- coding: utf-8 -*-
-# File: rlenv.py
-# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
-
-from abc import abstractmethod, ABCMeta
-
-__all__ = ['RLEnvironment', 'NaiveRLEnvironment']
-
-class RLEnvironment(object):
-    __meta__ = ABCMeta
-
-    @abstractmethod
-    def current_state(self):
-        """
-        Observe, return a state representation
-        """
-
-    @abstractmethod
-    def action(self, act):
-        """
-        Perform an action
-        :params act: the action
-        :returns: (reward, isOver)
-        """
-
-class NaiveRLEnvironment(RLEnvironment):
-    def __init__(self):
-        self.k = 0
-    def current_state(self):
-        self.k += 1
-        return self.k
-    def action(self, act):
-        self.k = act
-        return (self.k, self.k > 10)
@@ -22,6 +22,7 @@ def get_default_sess_config(mem_fraction=0.9):
     conf = tf.ConfigProto()
     conf.gpu_options.per_process_gpu_memory_fraction = mem_fraction
     conf.gpu_options.allocator_type = 'BFC'
+    conf.gpu_options.allow_growth = True
     conf.allow_soft_placement = True
     #conf.log_device_placement = True
     return conf
......
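The added line changes how TensorFlow claims GPU memory: with allow_growth the process starts with a small allocation and grows it on demand, rather than reserving per_process_gpu_memory_fraction of the device up front. A hedged usage sketch with the TensorFlow 0.x/1.x Session API follows; only get_default_sess_config and its body come from the hunk above, the surrounding code is illustrative.

# Sketch of how a config like the one above is typically consumed; the exact
# surrounding tensorpack code is not shown in this diff.
import tensorflow as tf

def get_default_sess_config(mem_fraction=0.9):
    conf = tf.ConfigProto()
    conf.gpu_options.per_process_gpu_memory_fraction = mem_fraction
    conf.gpu_options.allocator_type = 'BFC'
    conf.gpu_options.allow_growth = True   # grow GPU memory lazily instead of reserving it all
    conf.allow_soft_placement = True
    return conf

sess = tf.Session(config=get_default_sess_config())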