Commit 4c7348c3 authored by Yuxin Wu

change how expreplay works...

parent 0c5e39eb
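In short: the replay memory no longer stores the successor state inside each Experience; frame history and the next state are rebuilt at sampling time from consecutive memory entries, and a new HistoryFramePlayer wrapper does the same frame stacking at play/eval time. A rough sketch of the data-layout change (a sketch only, not code from the repo; the field names are taken from the diff below):

from collections import namedtuple

# before this commit: the successor state was stored explicitly with every transition
OldExperience = namedtuple('Experience',
    ['state', 'action', 'reward', 'next', 'isOver'])

# after this commit: only single frames are stored; history and next-state stacks
# are rebuilt from neighbouring entries in ExpReplay.sample_one()
Experience = namedtuple('Experience',
    ['state', 'action', 'reward', 'isOver'])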
@@ -3,8 +3,8 @@
 # File: DQN.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

-import tensorflow as tf
 import numpy as np
+import tensorflow as tf
 import os, sys, re
 import random
 import argparse
@@ -22,7 +22,7 @@ from tensorpack.predict import PredictConfig, get_predict_func, ParallelPredictW
 from tensorpack.tfutils import symbolic_functions as symbf
 from tensorpack.callbacks import *
-from tensorpack.dataflow.dataset import AtariDriver, AtariPlayer
+from tensorpack.dataflow.dataset import AtariPlayer
 from tensorpack.dataflow.RL import ExpReplay

 """
@@ -36,13 +36,13 @@ IMAGE_SIZE = 84
 NUM_ACTIONS = None
 FRAME_HISTORY = 4
 ACTION_REPEAT = 3
-HEIGHT_RANGE = (36, 204) # for breakout
-# HEIGHT_RANGE = (28, -8) # for pong
+#HEIGHT_RANGE = (36, 204) # for breakout
+HEIGHT_RANGE = (28, -8) # for pong
 GAMMA = 0.99
 BATCH_SIZE = 32
 INIT_EXPLORATION = 1
-EXPLORATION_EPOCH_ANNEAL = 0.0025
+EXPLORATION_EPOCH_ANNEAL = 0.0020
 END_EXPLORATION = 0.1
 MEMORY_SIZE = 1e6
@@ -62,15 +62,20 @@ class Model(ModelDesc):
     def _get_DQN_prediction(self, image, is_training):
         """ image: [0,255]"""
-        image = image / 128.0 - 1
-        with argscope(Conv2D, nl=tf.nn.relu, use_bias=True):
-            l = Conv2D('conv0', image, out_channel=32, kernel_shape=5, stride=2)
-            l = Conv2D('conv1', l, out_channel=32, kernel_shape=5, stride=2)
-            l = Conv2D('conv2', l, out_channel=64, kernel_shape=4, stride=2)
+        image = image / 255.0
+        with argscope(Conv2D, nl=PReLU.f, use_bias=True):
+            l = Conv2D('conv0', image, out_channel=32, kernel_shape=5, stride=1)
+            l = MaxPooling('pool0', l, 2)
+            l = Conv2D('conv1', l, out_channel=32, kernel_shape=5, stride=1)
+            l = MaxPooling('pool1', l, 2)
+            l = Conv2D('conv2', l, out_channel=64, kernel_shape=4)
+            l = MaxPooling('pool2', l, 2)
             l = Conv2D('conv3', l, out_channel=64, kernel_shape=3)
-        l = FullyConnected('fc0', l, 512)
-        l = FullyConnected('fct', l, out_dim=NUM_ACTIONS, nl=tf.identity, summary_activation=False)
+            l = MaxPooling('pool3', l, 2)
+            l = Conv2D('conv4', l, out_channel=64, kernel_shape=3)
+        l = FullyConnected('fc0', l, 512, nl=lambda x, name: LeakyReLU.f(x, 0.01, name))
+        l = FullyConnected('fct', l, out_dim=NUM_ACTIONS, nl=tf.identity)
         return l

     def _build_graph(self, inputs, is_training):
@@ -136,14 +141,14 @@ def play_one_episode(player, func, verbose=False):
     tot_reward = 0
     que = deque(maxlen=30)
     while True:
-        s = player.current_state()
+        s = player.current_state() # XXX
         outputs = func([[s]])
         action_value = outputs[0][0]
         act = action_value.argmax()
         if verbose:
             print action_value, act
         if random.random() < 0.01:
-            act = random.choice(range(player.driver.get_num_actions()))
+            act = random.choice(range(NUM_ACTIONS))
         if len(que) == que.maxlen \
                 and que.count(que[0]) == que.maxlen:
             act = 1 # hack, avoid stuck
@@ -156,10 +161,11 @@ def play_one_episode(player, func, verbose=False):
     return tot_reward

 def play_model(model_path, romfile):
-    player = AtariPlayer(AtariDriver(romfile, viz=0.01, height_range=HEIGHT_RANGE),
-            action_repeat=ACTION_REPEAT)
+    player = HistoryFramePlayer(AtariPlayer(
+        romfile, viz=0.01, height_range=HEIGHT_RANGE,
+        frame_skip=ACTION_REPEAT), FRAME_HISTORY)
     global NUM_ACTIONS
-    NUM_ACTIONS = player.driver.get_num_actions()
+    NUM_ACTIONS = player.player.get_num_actions()
     M = Model()
     cfg = PredictConfig(
@@ -186,10 +192,11 @@ def eval_model_multiprocess(model_path, romfile):
             self.outq = outqueue

         def run(self):
-            player = AtariPlayer(AtariDriver(romfile, viz=0, height_range=HEIGHT_RANGE),
-                    action_repeat=ACTION_REPEAT)
+            player = HistoryFramePlayer(AtariPlayer(
+                romfile, viz=0, height_range=HEIGHT_RANGE,
+                frame_skip=ACTION_REPEAT), FRAME_HISTORY)
             global NUM_ACTIONS
-            NUM_ACTIONS = player.driver.get_num_actions()
+            NUM_ACTIONS = player.player.get_num_actions()
             self._init_runtime()
             while True:
                 score = play_one_episode(player, self.func)
@@ -226,15 +233,15 @@ def get_config(romfile):
         os.path.join('train_log', basename[:basename.rfind('.')]))

     M = Model()
-    driver = AtariDriver(romfile, height_range=HEIGHT_RANGE)
+    player = AtariPlayer(
+        romfile, height_range=HEIGHT_RANGE,
+        frame_skip=ACTION_REPEAT)
     global NUM_ACTIONS
-    NUM_ACTIONS = driver.get_num_actions()
+    NUM_ACTIONS = player.get_num_actions()

     dataset_train = ExpReplay(
             predictor=current_predictor,
-            player=AtariPlayer(
-                driver, hist_len=FRAME_HISTORY,
-                action_repeat=ACTION_REPEAT),
+            player=player,
             num_actions=NUM_ACTIONS,
             memory_size=MEMORY_SIZE,
             batch_size=BATCH_SIZE,
@@ -242,22 +249,23 @@ def get_config(romfile):
             exploration=INIT_EXPLORATION,
             end_exploration=END_EXPLORATION,
             exploration_epoch_anneal=EXPLORATION_EPOCH_ANNEAL,
-            reward_clip=(-1, 2))
+            reward_clip=(-1, 1),
+            history_len=FRAME_HISTORY)

-    lr = tf.Variable(0.0025, trainable=False, name='learning_rate')
+    lr = tf.Variable(0.00025, trainable=False, name='learning_rate')
     tf.scalar_summary('learning_rate', lr)

     class Evaluator(Callback):
         def _trigger_epoch(self):
             logger.info("Evaluating...")
             output = subprocess.check_output(
-                """{} --task eval --rom {} --load {} 2>&1 | grep Average""".format(
+                """CUDA_VISIBLE_DEVICES= {} --task eval --rom {} --load {} 2>&1 | grep Average""".format(
                 sys.argv[0], romfile, os.path.join(logger.LOG_DIR, 'checkpoint')), shell=True)
             output = output.strip()
             output = output[output.find(']')+1:]
-            mean, maximum = re.findall('[0-9\.]+', output)
-            self.trainer.write_scalar_summary('eval_mean_score', mean)
-            self.trainer.write_scalar_summary('eval_max_score', maximum)
+            mean, maximum = re.findall('[0-9\.\-]+', output)[-2:]
+            self.trainer.write_scalar_summary('mean_score', mean)
+            self.trainer.write_scalar_summary('max_score', maximum)

     return TrainConfig(
         dataset=dataset_train,
@@ -269,7 +277,7 @@ def get_config(romfile):
             HumanHyperParamSetter((dataset_train, 'exploration'), 'hyper.txt'),
             TargetNetworkUpdator(M),
             dataset_train,
-            PeriodicCallback(Evaluator(), 1),
+            PeriodicCallback(Evaluator(), 2),
         ]),
         session_config=get_default_sess_config(0.5),
         model=M,
...
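The new tower in _get_DQN_prediction trades the old strided convolutions for stride-1 convolutions followed by 2x2 max-pooling. A quick standalone sanity check of the spatial sizes this should produce on an 84x84 input (a sketch only, assuming SAME padding for the convolutions, which I believe is tensorpack's Conv2D default, and non-padded 2x2 pooling):

# rough bookkeeping only; not code from the repo
size = 84                          # IMAGE_SIZE
for pool in ['pool0', 'pool1', 'pool2', 'pool3']:
    size //= 2                     # each conv keeps the size, each pool halves it
    print(pool, size)              # 42, 21, 10, 5
print('fc0 input features:', size * size * 64)   # 5 * 5 * 64 = 1600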
@@ -19,10 +19,10 @@ from tensorpack.callbacks.base import Callback
 Implement RL-related data preprocessing
 """

-__all__ = ['ExpReplay', 'RLEnvironment', 'NaiveRLEnvironment']
+__all__ = ['ExpReplay', 'RLEnvironment', 'NaiveRLEnvironment', 'HistoryFramePlayer']

 Experience = namedtuple('Experience',
-    ['state', 'action', 'reward', 'next', 'isOver'])
+    ['state', 'action', 'reward', 'isOver'])

 class RLEnvironment(object):
     __meta__ = ABCMeta
@@ -65,6 +65,49 @@ class NaiveRLEnvironment(RLEnvironment):
         self.k = act
         return (self.k, self.k > 10)

+class ProxyPlayer(RLEnvironment):
+    def __init__(self, player):
+        self.player = player
+
+    def get_stat(self):
+        return self.player.get_stat()
+
+    def reset_stat(self):
+        self.player.reset_stat()
+
+    def current_state(self):
+        return self.player.current_state()
+
+    def action(self, act):
+        return self.player.action(act)
+
+class HistoryFramePlayer(ProxyPlayer):
+    def __init__(self, player, hist_len):
+        super(HistoryFramePlayer, self).__init__(player)
+        self.history = deque(maxlen=hist_len)
+
+        s = self.player.current_state()
+        self.history.append(s)
+
+    def current_state(self):
+        assert len(self.history) != 0
+        diff_len = self.history.maxlen - len(self.history)
+        if diff_len == 0:
+            return np.concatenate(self.history, axis=2)
+        zeros = [np.zeros_like(self.history[0]) for k in range(diff_len)]
+        for k in self.history:
+            zeros.append(k)
+        return np.concatenate(zeros, axis=2)
+
+    def action(self, act):
+        r, isOver = self.player.action(act)
+        s = self.player.current_state()
+        self.history.append(s)
+
+        if isOver:  # s would be a new episode
+            self.history.clear()
+            self.history.append(s)
+        return (r, isOver)
+
 class ExpReplay(DataFlow, Callback):
     """
@@ -82,11 +125,15 @@ class ExpReplay(DataFlow, Callback):
                  end_exploration=0.1,
                  exploration_epoch_anneal=0.002,
                  reward_clip=None,
-                 new_experience_per_step=1
+                 new_experience_per_step=1,
+                 history_len=1
                  ):
         """
-        :param predictor: callable. called with a state, return a distribution
+        :param predictor: a callable calling the up-to-date network.
+            called with a state, return a distribution
         :param player: a `RLEnvironment`
+        :param num_actions: int
+        :param history_len: length of history frames to concat. zero-filled initial frames
         """
         for k, v in locals().items():
             if k != 'self':
@@ -106,51 +153,83 @@ class ExpReplay(DataFlow, Callback):
             raise RuntimeError("Don't run me in multiple processes")

     def _populate_exp(self):
-        p = self.rng.rand()
         old_s = self.player.current_state()
-        if p <= self.exploration:
+        if self.rng.rand() <= self.exploration:
             act = self.rng.choice(range(self.num_actions))
         else:
-            act = np.argmax(self.predictor(old_s))  # TODO race condition in session?
+            # build a history state
+            ss = [old_s]
+            for k in range(1, self.history_len):
+                hist_exp = self.mem[-k]
+                if hist_exp.isOver:
+                    ss.append(np.zeros_like(ss[0]))
+                else:
+                    ss.append(hist_exp.state)
+            ss = np.concatenate(ss, axis=2)
+            act = np.argmax(self.predictor(ss))
         reward, isOver = self.player.action(act)
         if self.reward_clip:
             reward = np.clip(reward, self.reward_clip[0], self.reward_clip[1])
-        s = self.player.current_state()

-        # s is considered useless if isOver==True
-        self.mem.append(Experience(old_s, act, reward, s, isOver))
+        #def view_state(state):
+            #""" for debug state representation"""
+            #r = np.concatenate([state[:,:,k] for k in range(state.shape[2])], axis=1)
+            #print r.shape
+            #cv2.imshow("state", r)
+            #cv2.waitKey()
+        #print act, reward
+        #view_state(s)
+
+        self.mem.append(Experience(old_s, act, reward, isOver))

     def get_data(self):
-        # new s is considered useless if isOver==True
         while True:
-            idxs = self.rng.randint(len(self.mem), size=self.batch_size)
-            batch_exp = [self.mem[k] for k in idxs]
+            batch_exp = [self.sample_one() for _ in range(self.batch_size)]
+
+            def view_state(state, next_state):
+                """ for debug state representation"""
+                r = np.concatenate([state[:,:,k] for k in range(self.history_len)], axis=1)
+                r2 = np.concatenate([next_state[:,:,k] for k in range(self.history_len)], axis=1)
+                print r.shape
+                r = np.concatenate([r, r2], axis=0)
+                cv2.imshow("state", r)
+                cv2.waitKey()
+            exp = batch_exp[0]
+            print("Act: ", exp[3], " reward:", exp[2], " isOver: ", exp[4])
+            view_state(exp[0], exp[1])
+
             yield self._process_batch(batch_exp)
             for _ in range(self.new_experience_per_step):
                 self._populate_exp()

+    def sample_one(self):
+        """ return the transition tuple for
+            [idx, idx+history_len] -> [idx+1, idx+1+history_len]
+            it's the transition from state idx+history_len-1 to state idx+history_len
+        """
+        # look for a state to start with
+        # when x.isOver==True, (x+1).state is of a different episode
+        idx = self.rng.randint(len(self.mem) - self.history_len - 1)
+        start_idx = idx + self.history_len - 1
+
+        def concat(idx):
+            v = [self.mem[x].state for x in range(idx, idx+self.history_len)]
+            return np.concatenate(v, axis=2)
+        state = concat(idx)
+        next_state = concat(idx + 1)
+        reward = self.mem[start_idx].reward
+        action = self.mem[start_idx].action
+        isOver = self.mem[start_idx].isOver
+
+        # zero-fill state before starting
+        zero_fill = False
+        for k in range(1, self.history_len):
+            if self.mem[start_idx-k].isOver:
+                zero_fill = True
+            if zero_fill:
+                state[:,:,-k-1] = 0
+                if k + 2 <= self.history_len:
+                    next_state[:,:,-k-2] = 0
+        return (state, next_state, reward, action, isOver)
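sample_one draws idx uniformly, stacks mem[idx .. idx+history_len-1] as state and mem[idx+1 .. idx+history_len] as next_state, and takes reward/action/isOver from start_idx = idx + history_len - 1; frames that cross an episode boundary (detected via isOver) are zeroed out. A tiny worked example of the indexing with history_len=4 (indices only, not a real memory):

history_len = 4
idx = 10                                   # sampled uniformly from the memory
start_idx = idx + history_len - 1          # 13: the transition being learned from

state_frames = list(range(idx, idx + history_len))               # mem[10..13]
next_state_frames = list(range(idx + 1, idx + 1 + history_len))  # mem[11..14]
print(state_frames, next_state_frames)     # [10, 11, 12, 13] [11, 12, 13, 14]
# reward, action and isOver all come from mem[13]; if mem[12], mem[11] or mem[10]
# has isOver=True, the older frames of both stacks are zero-filled because they
# belong to a previous episode.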
     def _process_batch(self, batch_exp):
-        state_shape = batch_exp[0].state.shape
-        state = np.zeros((self.batch_size, ) + state_shape, dtype='float32')
-        next_state = np.zeros((self.batch_size, ) + state_shape, dtype='float32')
-        reward = np.zeros((self.batch_size,), dtype='float32')
-        action = np.zeros((self.batch_size,), dtype='int32')
-        isOver = np.zeros((self.batch_size,), dtype='bool')
-        for idx, b in enumerate(batch_exp):
-            state[idx] = b.state
-            action[idx] = b.action
-            next_state[idx] = b.next
-            reward[idx] = b.reward
-            isOver[idx] = b.isOver
+        state = np.array([e[0] for e in batch_exp])
+        next_state = np.array([e[1] for e in batch_exp])
+        reward = np.array([e[2] for e in batch_exp])
+        action = np.array([e[3] for e in batch_exp])
+        isOver = np.array([e[4] for e in batch_exp])
         return [state, action, reward, next_state, isOver]

     # Callback-related:
@@ -170,12 +249,16 @@ class ExpReplay(DataFlow, Callback):
 if __name__ == '__main__':
-    from tensorpack.dataflow.dataset import AtariDriver, AtariPlayer
+    from tensorpack.dataflow.dataset import AtariPlayer
+    import sys
     predictor = lambda x: np.array([1,1,1,1])
     predictor.initialized = False
-    E = AtariExpReplay(predictor, predictor,
-            AtariPlayer(AtariDriver('../../space_invaders.bin', viz=0.01)),
-            populate_size=1000)
+    player = AtariPlayer(sys.argv[1], viz=0, frame_skip=20)
+    E = ExpReplay(predictor,
+            player=player,
+            num_actions=player.get_num_actions(),
+            populate_size=1001,
+            history_len=4)
     E.init_memory()
     for k in E.get_data():
...
@@ -9,6 +9,7 @@ import os
 import cv2
 from collections import deque
 from ...utils import get_rng, logger
+from ...utils.stat import StatCounter
 from ..RL import RLEnvironment

 try:
@@ -16,23 +17,27 @@ try:
 except ImportError:
     logger.warn("Cannot import ale_python_interface, Atari won't be available.")

-__all__ = ['AtariDriver', 'AtariPlayer']
+__all__ = ['AtariPlayer']

-class AtariDriver(RLEnvironment):
+class AtariPlayer(RLEnvironment):
     """
     A wrapper for atari emulator.
     """
-    def __init__(self, rom_file, viz=0, height_range=(None,None)):
+    def __init__(self, rom_file, viz=0, height_range=(None,None),
+                 frame_skip=4, image_shape=(84, 84)):
         """
         :param rom_file: path to the rom
         :param frame_skip: skip every k frames
+        :param image_shape: (w, h)
+        :param height_range: (h1, h2) to cut
         :param viz: the delay. visualize the game while running. 0 to disable
         """
+        super(AtariPlayer, self).__init__()
         self.ale = ALEInterface()
         self.rng = get_rng(self)
         self.ale.setInt("random_seed", self.rng.randint(self.rng.randint(0, 1000)))
-        self.ale.setInt("frame_skip", 1)
+        self.ale.setInt("frame_skip", frame_skip)
         self.ale.setBool('color_averaging', True)
         self.ale.loadROM(rom_file)
         self.width, self.height = self.ale.getScreenDims()
@@ -45,9 +50,11 @@ class AtariDriver(RLEnvironment):
         if self.viz and isinstance(self.viz, float):
             cv2.startWindowThread()
             cv2.namedWindow(self.romname)
+            self.framenum = 0

         self.height_range = height_range
-        self.framenum = 0
+        self.image_shape = image_shape
+        self.current_episode_score = StatCounter()

         self._reset()
@@ -61,9 +68,9 @@ class AtariDriver(RLEnvironment):
     def current_state(self):
         """
-        :returns: a gray-scale image, max-pooled over the last frame.
+        :returns: a gray-scale (h, w, 1) image
         """
-        now = self._grab_raw_image()
+        ret = self._grab_raw_image()
         if self.viz:
             if isinstance(self.viz, float):
                 cv2.imshow(self.romname, ret)
@@ -73,6 +80,8 @@ class AtariDriver(RLEnvironment):
                 self.framenum += 1
         ret = ret[self.height_range[0]:self.height_range[1],:]
         ret = cv2.cvtColor(ret, cv2.COLOR_BGR2YUV)[:,:,0]
+        ret = cv2.resize(ret, self.image_shape)
+        ret = np.expand_dims(ret, axis=2)
         return ret

     def get_num_actions(self):
@@ -82,6 +91,7 @@ class AtariDriver(RLEnvironment):
         return len(self.actions)

     def _reset(self):
+        self.current_episode_score.reset()
         self.ale.reset_game()

     def action(self, act):
@@ -90,80 +100,13 @@ class AtariDriver(RLEnvironment):
         :returns: (reward, isOver)
         """
         r = self.ale.act(self.actions[act])
+        self.current_episode_score.feed(r)
         isOver = self.ale.game_over()
         if isOver:
+            self.stats['score'].append(self.current_episode_score.sum)
             self._reset()
         return (r, isOver)

-class AtariPlayer(RLEnvironment):
-    """ An Atari game player with limited memory and FPS"""
-    def __init__(self, driver, hist_len=4, action_repeat=4, image_shape=(84,84)):
-        """
-        :param driver: an `AtariDriver` instance.
-        :param hist_len: history(memory) length
-        :param action_repeat: repeat each action `action_repeat` times and skip those frames
-        :param image_shape: the shape of the observed image
-        """
-        super(AtariPlayer, self).__init__()
-        for k, v in locals().items():
-            if k != 'self':
-                setattr(self, k, v)
-        self.last_act = 0
-        self.frames = deque(maxlen=hist_len)
-        self.current_accum_score = 0
-        self.restart()
-
-    def restart(self):
-        """
-        Restart the game and populate frames with the beginning frame
-        """
-        self.current_accum_score = 0
-        self.frames.clear()
-        s = self.driver.current_state()
-        s = cv2.resize(s, self.image_shape)
-        for _ in range(self.hist_len):
-            self.frames.append(s)
-
-    def current_state(self):
-        """
-        Return a current state of shape `image_shape + (hist_len,)`
-        """
-        return self._build_state()
-
-    def action(self, act):
-        """
-        Perform an action
-        :param act: index of the action
-        :returns: (reward, isOver)
-        """
-        self.last_act = act
-        return self._observe()
-
-    def _build_state(self):
-        assert len(self.frames) == self.hist_len
-        m = np.array(self.frames)
-        m = m.transpose([1,2,0])
-        return m
-
-    def _observe(self):
-        """ if isOver==True, current_state will return the new episode
-        """
-        totr = 0
-        for k in range(self.action_repeat):
-            r, isOver = self.driver.action(self.last_act)
-            s = self.driver.current_state()
-            totr += r
-            if isOver:
-                break
-        s = cv2.resize(s, self.image_shape)
-        self.current_accum_score += totr
-        self.frames.append(s)
-        if isOver:
-            self.stats['score'].append(self.current_accum_score)
-            self.restart()
-        return (totr, isOver)
-
     def get_stat(self):
         try:
             return {'avg_score': np.mean(self.stats['score']),
@@ -173,7 +116,8 @@ class AtariPlayer(RLEnvironment):
 if __name__ == '__main__':
     import sys
-    a = AtariDriver(sys.argv[1], viz=0.01, height_range=(28,-8))
+    a = AtariPlayer(sys.argv[1],
+        viz=0.01, height_range=(28,-8))
     num = a.get_num_actions()
     rng = get_rng(num)
     import time
@@ -183,6 +127,7 @@ if __name__ == '__main__':
         act = rng.choice(range(num))
         print act
         r, o = a.action(act)
+        a.current_state()
         #time.sleep(0.1)
         print(r, o)
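current_state() in the rewritten AtariPlayer now returns one preprocessed frame per call: crop to height_range, keep the Y channel of a YUV conversion, resize to image_shape, then add a trailing channel axis so frames can be stacked along axis 2. A minimal sketch of the same pipeline on a raw screen (plain numpy/OpenCV; the input array here is synthetic, not an ALE frame):

import cv2
import numpy as np

screen = np.random.randint(0, 255, (210, 160, 3), dtype=np.uint8)  # fake 210x160 RGB screen
height_range = (28, -8)       # the pong setting from DQN.py
image_shape = (84, 84)

frame = screen[height_range[0]:height_range[1], :]       # crop away score bar / borders
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2YUV)[:, :, 0]  # luminance only
frame = cv2.resize(frame, image_shape)                   # (84, 84)
frame = np.expand_dims(frame, axis=2)                    # (84, 84, 1), ready for HistoryFramePlayer
print(frame.shape)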