Commit 961b0ee4 authored by Yuxin Wu

move exp_replay

parent fc0b965a
@@ -245,15 +245,17 @@ def get_config(romfile):
     global NUM_ACTIONS
     NUM_ACTIONS = driver.get_num_actions()
-    dataset_train = AtariExpReplay(
+    dataset_train = ExpReplay(
         predictor=current_predictor,
         player=AtariPlayer(
             driver, hist_len=FRAME_HISTORY,
             action_repeat=ACTION_REPEAT),
+        num_actions=NUM_ACTIONS,
         memory_size=MEMORY_SIZE,
         batch_size=BATCH_SIZE,
         populate_size=INIT_MEMORY_SIZE,
-        exploration=INIT_EXPLORATION)
+        exploration=INIT_EXPLORATION,
+        reward_clip=(-1, 2))
     lr = tf.Variable(0.0025, trainable=False, name='learning_rate')
     tf.scalar_summary('learning_rate', lr)
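For reference, the renamed ExpReplay now receives the action count and the clipping range explicitly instead of reading them from the Atari driver. A minimal sketch of the resulting call, assuming the uppercase constants keep the values defined earlier in this config:

```python
# Sketch of the new call site; all uppercase constants are assumed to be
# defined earlier in the config, as in the hunk above.
dataset_train = ExpReplay(
    predictor=current_predictor,    # callable: state -> per-action score distribution
    player=AtariPlayer(driver, hist_len=FRAME_HISTORY,
                       action_repeat=ACTION_REPEAT),
    num_actions=NUM_ACTIONS,        # passed in explicitly instead of queried from the driver
    memory_size=MEMORY_SIZE,
    batch_size=BATCH_SIZE,
    populate_size=INIT_MEMORY_SIZE,
    exploration=INIT_EXPLORATION,
    reward_clip=(-1, 2))            # optional (low, high) clip; None disables clipping
```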
@@ -2,8 +2,9 @@
 # File: __init__.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
-import numpy
-import cv2 # fix https://github.com/tensorflow/tensorflow/issues/1924
+import numpy # avoid https://github.com/tensorflow/tensorflow/issues/2034
+import cv2 # avoid https://github.com/tensorflow/tensorflow/issues/1924
 from . import models
 from . import train
 from . import utils
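Both import comments point at import-order crashes in TensorFlow. A minimal standalone sketch of the workaround they refer to, under the assumption that TensorFlow itself is only imported afterwards by the submodules:

```python
# Importing numpy and cv2 before anything pulls in TensorFlow works around the
# crashes tracked in tensorflow/tensorflow issues #2034 and #1924; the imports
# are otherwise unused at this point, only their order matters.
import numpy  # noqa: F401
import cv2    # noqa: F401

import tensorflow as tf  # safe to import now that numpy/cv2 are loaded
```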
 #!/usr/bin/env python2
 # -*- coding: utf-8 -*-
-# File: exp_replay.py
+# File: RL.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
-from tensorpack.dataflow import *
-from tensorpack.dataflow.dataset import AtariDriver, AtariPlayer
+from .base import DataFlow
 from tensorpack.utils import *
 from tqdm import tqdm
 import random
 import numpy as np
 import cv2
 from collections import deque, namedtuple
+"""
+Implement RL-related data preprocessing
+"""
+__all__ = ['ExpReplay']
 
 Experience = namedtuple('Experience',
     ['state', 'action', 'reward', 'next', 'isOver'])
 
 def view_state(state):
+    # for debug
     r = np.concatenate([state[:,:,k] for k in range(state.shape[2])], axis=1)
     print r.shape
     cv2.imshow("state", r)
     cv2.waitKey()
 
-class AtariExpReplay(DataFlow):
+class ExpReplay(DataFlow):
     """
-    Implement experience replay
+    Implement experience replay.
     """
     def __init__(self,
             predictor,
             player,
+            num_actions,
             memory_size=1e6,
             batch_size=32,
             populate_size=50000,
-            exploration=1):
+            exploration=1,
+            reward_clip=None):
         """
         :param predictor: callable. called with a state, return a distribution
+        :param player: a `RLEnvironment`
         """
         for k, v in locals().items():
             if k != 'self':
                 setattr(self, k, v)
-        self.num_actions = self.player.driver.get_num_actions()
         logger.info("Number of Legal actions: {}".format(self.num_actions))
         self.mem = deque(maxlen=memory_size)
         self.rng = get_rng(self)
@@ -62,7 +70,8 @@ class AtariExpReplay(DataFlow):
             else:
                 act = np.argmax(self.predictor(old_s))  # TODO race condition in session?
             reward, isOver = self.player.action(act)
-            reward = np.clip(reward, -1, 2)
+            if self.reward_clip:
+                reward = np.clip(reward, self.reward_clip[0], self.reward_clip[1])
             s = self.player.current_state()
             #print act, reward
@@ -94,6 +103,7 @@ class AtariExpReplay(DataFlow):
         return [state, action, reward, next_state, isOver]
 
 if __name__ == '__main__':
+    from tensorpack.dataflow.dataset import AtariDriver, AtariPlayer
     predictor = lambda x: np.array([1,1,1,1])
     predictor.initialized = False
     E = AtariExpReplay(predictor, predictor,
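The reward handling changes from a hard-coded np.clip(reward, -1, 2) to an optional reward_clip tuple, so clipping is now opt-in per call site. A small standalone sketch of the resulting behaviour (clip_reward is an illustrative helper, not part of the codebase):

```python
import numpy as np

def clip_reward(reward, reward_clip=None):
    # Mirrors the new ExpReplay logic: reward_clip is an optional (low, high)
    # tuple; when it is None the reward passes through unchanged.
    if reward_clip:
        reward = np.clip(reward, reward_clip[0], reward_clip[1])
    return reward

print(clip_reward(5.0, reward_clip=(-1, 2)))   # -> 2.0
print(clip_reward(-3.0, reward_clip=(-1, 2)))  # -> -1.0
print(clip_reward(5.0))                        # -> 5.0, clipping disabled
```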