Commit 961b0ee4 authored by Yuxin Wu

move exp_replay

parent fc0b965a
@@ -245,15 +245,17 @@ def get_config(romfile):
global NUM_ACTIONS
NUM_ACTIONS = driver.get_num_actions()
dataset_train = AtariExpReplay(
dataset_train = ExpReplay(
predictor=current_predictor,
player=AtariPlayer(
driver, hist_len=FRAME_HISTORY,
action_repeat=ACTION_REPEAT),
num_actions=NUM_ACTIONS,
memory_size=MEMORY_SIZE,
batch_size=BATCH_SIZE,
populate_size=INIT_MEMORY_SIZE,
exploration=INIT_EXPLORATION)
exploration=INIT_EXPLORATION,
reward_clip=(-1, 2))
lr = tf.Variable(0.0025, trainable=False, name='learning_rate')
tf.scalar_summary('learning_rate', lr)
......
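For context, the new reward_clip argument passed above replaces a clip range that was previously hard-coded inside the dataflow (see the np.clip call further down in this diff). A minimal sketch of the resulting behaviour, assuming plain NumPy semantics and a hypothetical raw_reward value:

import numpy as np

reward_clip = (-1, 2)   # the tuple passed in get_config above
raw_reward = 5.0        # hypothetical raw reward from the emulator
clipped = np.clip(raw_reward, reward_clip[0], reward_clip[1])   # -> 2.0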
@@ -2,8 +2,9 @@
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy
import cv2 # fix https://github.com/tensorflow/tensorflow/issues/1924
import numpy # avoid https://github.com/tensorflow/tensorflow/issues/2034
import cv2 # avoid https://github.com/tensorflow/tensorflow/issues/1924
from . import models
from . import train
from . import utils
......
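The comments on the two imports above point at known crashes that appear to depend on import order; pulling numpy and cv2 in from the package __init__ loads them before any module that imports TensorFlow. A minimal sketch of the same ordering in a standalone script (the ordering rationale is an assumption based on the linked issues, not stated in this diff):

import numpy            # load before tensorflow, see tensorflow issue 2034
import cv2               # load before tensorflow, see tensorflow issue 1924
import tensorflow as tf  # imported last, once numpy and cv2 are already loaded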
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
# File: exp_replay.py
# File: RL.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from tensorpack.dataflow import *
from tensorpack.dataflow.dataset import AtariDriver, AtariPlayer
from .base import DataFlow
from tensorpack.utils import *
from tqdm import tqdm
import random
import numpy as np
import cv2
from collections import deque, namedtuple
"""
Implement RL-related data preprocessing
"""
__all__ = ['ExpReplay']
Experience = namedtuple('Experience',
['state', 'action', 'reward', 'next', 'isOver'])
def view_state(state):
# for debug
r = np.concatenate([state[:,:,k] for k in range(state.shape[2])], axis=1)
print r.shape
cv2.imshow("state", r)
cv2.waitKey()
class AtariExpReplay(DataFlow):
class ExpReplay(DataFlow):
"""
Implement experience replay
Implement experience replay.
"""
def __init__(self,
predictor,
player,
num_actions,
memory_size=1e6,
batch_size=32,
populate_size=50000,
exploration=1):
exploration=1,
reward_clip=None):
"""
:param predictor: callable. Called with a state, returns a distribution.
:param player: a `RLEnvironment`
"""
for k, v in locals().items():
if k != 'self':
setattr(self, k, v)
self.num_actions = self.player.driver.get_num_actions()
logger.info("Number of Legal actions: {}".format(self.num_actions))
self.mem = deque(maxlen=int(memory_size))  # maxlen must be an integer; memory_size may be given as a float such as 1e6
self.rng = get_rng(self)
@@ -62,7 +70,8 @@ class AtariExpReplay(DataFlow):
else:
act = np.argmax(self.predictor(old_s)) # TODO race condition in session?
reward, isOver = self.player.action(act)
reward = np.clip(reward, -1, 2)
if self.reward_clip:
reward = np.clip(reward, self.reward_clip[0], self.reward_clip[1])
s = self.player.current_state()
#print act, reward
@@ -94,6 +103,7 @@ class AtariExpReplay(DataFlow):
return [state, action, reward, next_state, isOver]
if __name__ == '__main__':
from tensorpack.dataflow.dataset import AtariDriver, AtariPlayer
predictor = lambda x: np.array([1,1,1,1])
predictor.initialized = False
E = ExpReplay(predictor, predictor,
......
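Taken together, the renamed dataflow can be driven by hand much like the __main__ block above. A minimal sketch, assuming the file ends up importable as tensorpack.dataflow.RL and that AtariDriver is constructed from a ROM file path (both assumptions are not shown in this diff):

import numpy as np
from tensorpack.dataflow.RL import ExpReplay                      # assumed module path for the new RL.py
from tensorpack.dataflow.dataset import AtariDriver, AtariPlayer

driver = AtariDriver('breakout.bin')                              # hypothetical ROM path
predictor = lambda state: np.array([1, 1, 1, 1])                  # dummy uniform scores, as in __main__
predictor.initialized = False
E = ExpReplay(predictor,
              AtariPlayer(driver, hist_len=4, action_repeat=4),
              num_actions=driver.get_num_actions(),
              memory_size=1000,
              populate_size=101,
              reward_clip=(-1, 2))
# E can then be used like any other DataFlow once its memory is populated.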