Commit ff40a873 authored by Yuxin Wu

major refactor RL

parent 40e6a223
......@@ -18,12 +18,11 @@ from tensorpack.models import *
from tensorpack.utils import *
from tensorpack.utils.concurrency import ensure_proc_terminate, subproc_call
from tensorpack.utils.stat import *
from tensorpack.predict import PredictConfig, get_predict_func, ParallelPredictWorker
from tensorpack.predict import PredictConfig, get_predict_func, MultiProcessPredictWorker
from tensorpack.tfutils import symbolic_functions as symbf
from tensorpack.callbacks import *
from tensorpack.dataflow.dataset import AtariPlayer
from tensorpack.dataflow.RL import ExpReplay
from tensorpack.RL import AtariPlayer, ExpReplay
"""
Implement DQN in:
......
# -*- coding: UTF-8 -*-
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from pkgutil import walk_packages
import importlib
import os
import os.path
def _global_import(name):
p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p)
del globals()[name]
for k in lst:
globals()[k] = p.__dict__[k]
for _, module_name, _ in walk_packages(
[os.path.dirname(__file__)]):
if not module_name.startswith('_'):
_global_import(module_name)
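This new tensorpack/RL/__init__.py walks every submodule in the package and copies the names listed in each module's __all__ into the package namespace, so the public RL classes become importable directly from tensorpack.RL. A small sketch of the imports this enables, inferred from the __all__ lists and the example script's imports elsewhere in this commit:

# Sketch only: names re-exported at package level by the auto-import above
# (ExpReplay from expreplay.py, HistoryFramePlayer from common.py,
#  RLEnvironment/ProxyPlayer from envbase.py, AtariPlayer from atari.py).
from tensorpack.RL import ExpReplay, HistoryFramePlayer, RLEnvironment, AtariPlayer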
......@@ -8,9 +8,10 @@ import time
import os
import cv2
from collections import deque
from ...utils import get_rng, logger
from ...utils.stat import StatCounter
from ..RL import RLEnvironment
from ..utils import get_rng, logger
from ..utils.stat import StatCounter
from .envbase import RLEnvironment
try:
from ale_python_interface import ALEInterface
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: common.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import numpy as np
from collections import deque
from .envbase import ProxyPlayer
__all__ = ['HistoryFramePlayer']
class HistoryFramePlayer(ProxyPlayer):
def __init__(self, player, hist_len):
super(HistoryFramePlayer, self).__init__(player)
self.history = deque(maxlen=hist_len)
s = self.player.current_state()
self.history.append(s)
def current_state(self):
assert len(self.history) != 0
diff_len = self.history.maxlen - len(self.history)
if diff_len == 0:
return np.concatenate(self.history, axis=2)
zeros = [np.zeros_like(self.history[0]) for k in range(diff_len)]
for k in self.history:
zeros.append(k)
return np.concatenate(zeros, axis=2)
def action(self, act):
r, isOver = self.player.action(act)
s = self.player.current_state()
self.history.append(s)
if isOver: # s would be a new episode
self.history.clear()
self.history.append(s)
return (r, isOver)
class AvoidNoOpPlayer(ProxyPlayer):
pass # TODO
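HistoryFramePlayer keeps the last hist_len observations in a deque and concatenates them along the channel axis (axis=2), zero-padding until the history is full and clearing it when an episode ends. A minimal usage sketch; DummyPlayer below is hypothetical and exists only to exercise the wrapper:

# Hedged sketch: stacking frames with HistoryFramePlayer on a toy environment.
import numpy as np
from tensorpack.RL import RLEnvironment, HistoryFramePlayer

class DummyPlayer(RLEnvironment):
    """Hypothetical 2x2 single-channel environment, not part of tensorpack."""
    def __init__(self):
        super(DummyPlayer, self).__init__()
        self.t = 0
    def current_state(self):
        return np.full((2, 2, 1), self.t, dtype='float32')
    def action(self, act):
        self.t += 1
        return (1.0, self.t >= 5)      # reward 1 per step, episode ends after 5 steps
    def get_stat(self):
        return self.stats

player = HistoryFramePlayer(DummyPlayer(), 4)
s = player.current_state()             # zero-padded until 4 real frames are seen
assert s.shape == (2, 2, 4)            # frames stacked along axis=2
r, isOver = player.action(0)           # history is cleared automatically when isOver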
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: envbase.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from abc import abstractmethod, ABCMeta
from collections import defaultdict
__all__ = ['RLEnvironment', 'NaiveRLEnvironment', 'ProxyPlayer']
class RLEnvironment(object):
__metaclass__ = ABCMeta  # Python 2 metaclass assignment so @abstractmethod is enforced
def __init__(self):
self.reset_stat()
@abstractmethod
def current_state(self):
"""
Observe, return a state representation
"""
@abstractmethod
def action(self, act):
"""
Perform an action
:param act: the action
:returns: (reward, isOver)
"""
@abstractmethod
def get_stat(self):
"""
return a dict of statistics (e.g., score) after running for a while
"""
def reset_stat(self):
""" reset the statistics counter"""
self.stats = defaultdict(list)
class NaiveRLEnvironment(RLEnvironment):
""" for testing only"""
def __init__(self):
self.k = 0
def current_state(self):
self.k += 1
return self.k
def action(self, act):
self.k = act
return (self.k, self.k > 10)
class ProxyPlayer(RLEnvironment):
def __init__(self, player):
self.player = player
def get_stat(self):
return self.player.get_stat()
def reset_stat(self):
self.player.reset_stat()
def current_state(self):
return self.player.current_state()
def action(self, act):
return self.player.action(act)
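RLEnvironment defines the minimal player interface: current_state(), action(act) returning a (reward, isOver) pair, and get_stat(), with reset_stat() maintaining a stats dict; ProxyPlayer simply forwards all of these to a wrapped player. A minimal sketch of a custom environment against this interface; GridWalkPlayer and its dynamics are invented for illustration:

# Hedged sketch: a trivial 1-D environment implementing the RLEnvironment interface.
from tensorpack.RL import RLEnvironment

class GridWalkPlayer(RLEnvironment):
    """Hypothetical random-walk player: the state is the current position."""
    def __init__(self):
        super(GridWalkPlayer, self).__init__()    # sets up self.stats via reset_stat()
        self.pos = 0
    def current_state(self):
        return self.pos
    def action(self, act):
        self.pos += 1 if act == 1 else -1
        isOver = abs(self.pos) >= 10
        if isOver:
            self.stats['score'].append(self.pos)  # record an episode statistic
            self.pos = 0                          # start a new episode
        return (1.0, isOver)
    def get_stat(self):
        return self.stats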
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
# File: RL.py
# File: expreplay.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from abc import abstractmethod, ABCMeta
import random
import numpy as np
from collections import deque, namedtuple, defaultdict
from collections import deque, namedtuple
from tqdm import tqdm
import cv2
import six
from .base import DataFlow
from tensorpack.utils import *
from tensorpack.callbacks.base import Callback
from ..dataflow import DataFlow
from ..utils import *
from ..callbacks.base import Callback
"""
Implement RL-related data preprocessing
"""
__all__ = ['ExpReplay', 'RLEnvironment', 'NaiveRLEnvironment', 'HistoryFramePlayer']
__all__ = ['ExpReplay']
Experience = namedtuple('Experience',
['state', 'action', 'reward', 'isOver'])
class RLEnvironment(object):
__meta__ = ABCMeta
def __init__(self):
self.reset_stat()
@abstractmethod
def current_state(self):
"""
Observe, return a state representation
"""
@abstractmethod
def action(self, act):
"""
Perform an action
:params act: the action
:returns: (reward, isOver)
"""
@abstractmethod
def get_stat(self):
"""
return a dict of statistics (e.g., score) after running for a while
"""
def reset_stat(self):
""" reset the statistics counter"""
self.stats = defaultdict(list)
class NaiveRLEnvironment(RLEnvironment):
""" for testing only"""
def __init__(self):
self.k = 0
def current_state(self):
self.k += 1
return self.k
def action(self, act):
self.k = act
return (self.k, self.k > 10)
class ProxyPlayer(RLEnvironment):
def __init__(self, player):
self.player = player
def get_stat(self):
return self.player.get_stat()
def reset_stat(self):
self.player.reset_stat()
def current_state(self):
return self.player.current_state()
def action(self, act):
return self.player.action(act)
class HistoryFramePlayer(ProxyPlayer):
def __init__(self, player, hist_len):
super(HistoryFramePlayer, self).__init__(player)
self.history = deque(maxlen=hist_len)
s = self.player.current_state()
self.history.append(s)
def current_state(self):
assert len(self.history) != 0
diff_len = self.history.maxlen - len(self.history)
if diff_len == 0:
return np.concatenate(self.history, axis=2)
zeros = [np.zeros_like(self.history[0]) for k in range(diff_len)]
for k in self.history:
zeros.append(k)
return np.concatenate(zeros, axis=2)
def action(self, act):
r, isOver = self.player.action(act)
s = self.player.current_state()
self.history.append(s)
if isOver: # s would be a new episode
self.history.clear()
self.history.append(s)
return (r, isOver)
class ExpReplay(DataFlow, Callback):
"""
Implement experience replay in the paper
......@@ -182,6 +90,7 @@ class ExpReplay(DataFlow, Callback):
while True:
batch_exp = [self.sample_one() for _ in range(self.batch_size)]
#import cv2
#def view_state(state, next_state):
#""" for debugging state representation"""
#r = np.concatenate([state[:,:,k] for k in range(self.history_len)], axis=1)
......@@ -253,7 +162,7 @@ class ExpReplay(DataFlow, Callback):
self.player.reset_stat()
if __name__ == '__main__':
from tensorpack.dataflow.dataset import AtariPlayer
from .atari import AtariPlayer
import sys
predictor = lambda x: np.array([1,1,1,1])
predictor.initialized = False
......
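ExpReplay inherits from both DataFlow and Callback, so the same object can feed sampled batches of Experience(state, action, reward, isOver) tuples to the trainer (the sample_one loop above) while also being hooked into training to keep interacting with the player. A hedged wiring sketch in the spirit of the __main__ test above; the ExpReplay constructor arguments and the get_data() iteration method are assumptions for illustration, not shown in this diff:

# Hedged smoke-test sketch; keyword arguments below are assumed names.
import numpy as np
from tensorpack.RL import AtariPlayer, ExpReplay

predictor = lambda state: np.array([1, 1, 1, 1])   # stand-in Q-function, one value per action
predictor.initialized = False                      # same trick as the __main__ block above

player = AtariPlayer('breakout.bin')               # assumed: AtariPlayer takes a ROM path
expreplay = ExpReplay(predictor, player,           # assumed argument order and names
                      batch_size=32, memory_size=100000)

# DataFlow side: iterate over batches sampled from the replay memory.
for batch in expreplay.get_data():
    break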
......@@ -9,9 +9,9 @@ import os.path
def global_import(name):
p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p)
del globals()[name]
for k in lst:
globals()[k] = p.__dict__[k]
del globals()[name]
for _, module_name, _ in walk_packages(
[os.path.dirname(__file__)]):
......