Commit 0c5e39eb authored by Yuxin Wu's avatar Yuxin Wu

ataridriver as an rlenv

parent d5d7270a
......@@ -37,6 +37,7 @@ NUM_ACTIONS = None
FRAME_HISTORY = 4
ACTION_REPEAT = 3
HEIGHT_RANGE = (36, 204) # for breakout
# HEIGHT_RANGE = (28, -8) # for pong
GAMMA = 0.99
BATCH_SIZE = 32
......
......@@ -126,6 +126,7 @@ class ExpReplay(DataFlow, Callback):
#print act, reward
#view_state(s)
# s is considered useless if isOver==True
self.mem.append(Experience(old_s, act, reward, s, isOver))
def get_data(self):
......
......@@ -18,12 +18,11 @@ except ImportError:
__all__ = ['AtariDriver', 'AtariPlayer']
class AtariDriver(object):
class AtariDriver(RLEnvironment):
"""
A wrapper for atari emulator.
"""
def __init__(self, rom_file,
frame_skip=1, viz=0, height_range=(None,None)):
def __init__(self, rom_file, viz=0, height_range=(None,None)):
"""
:param rom_file: path to the rom
:param frame_skip: skip every k frames
......@@ -33,7 +32,8 @@ class AtariDriver(object):
self.rng = get_rng(self)
self.ale.setInt("random_seed", self.rng.randint(self.rng.randint(0, 1000)))
self.ale.setInt("frame_skip", frame_skip)
self.ale.setInt("frame_skip", 1)
self.ale.setBool('color_averaging', True)
self.ale.loadROM(rom_file)
self.width, self.height = self.ale.getScreenDims()
self.actions = self.ale.getMinimalActionSet()
......@@ -46,10 +46,10 @@ class AtariDriver(object):
cv2.startWindowThread()
cv2.namedWindow(self.romname)
self._reset()
self.last_image = self._grab_raw_image()
self.framenum = 0
self.height_range = height_range
self.framenum = 0
self._reset()
def _grab_raw_image(self):
"""
......@@ -59,14 +59,11 @@ class AtariDriver(object):
self.ale.getScreenRGB(m)
return m.reshape((self.height, self.width, 3))
def grab_image(self):
def current_state(self):
"""
:returns: a gray-scale image, max-pooled over the last frame.
"""
now = self._grab_raw_image()
ret = np.maximum(now, self.last_image)
self.last_image = now
ret = ret[self.height_range[0]:self.height_range[1],:] # several online repos all use this
if self.viz:
if isinstance(self.viz, float):
cv2.imshow(self.romname, ret)
......@@ -74,6 +71,7 @@ class AtariDriver(object):
else:
cv2.imwrite("{}/{:06d}.jpg".format(self.viz, self.framenum), ret)
self.framenum += 1
ret = ret[self.height_range[0]:self.height_range[1],:]
ret = cv2.cvtColor(ret, cv2.COLOR_BGR2YUV)[:,:,0]
return ret
......@@ -86,17 +84,16 @@ class AtariDriver(object):
def _reset(self):
self.ale.reset_game()
def next(self, act):
def action(self, act):
"""
:param act: an index of the action
:returns: (next_image, reward, isOver)
:returns: (reward, isOver)
"""
r = self.ale.act(self.actions[act])
s = self.grab_image()
isOver = self.ale.game_over()
if isOver:
self._reset()
return (s, r, isOver)
return (r, isOver)
class AtariPlayer(RLEnvironment):
""" An Atari game player with limited memory and FPS"""
......@@ -122,7 +119,7 @@ class AtariPlayer(RLEnvironment):
"""
self.current_accum_score = 0
self.frames.clear()
s = self.driver.grab_image()
s = self.driver.current_state()
s = cv2.resize(s, self.image_shape)
for _ in range(self.hist_len):
......@@ -138,7 +135,7 @@ class AtariPlayer(RLEnvironment):
"""
Perform an action
:param act: index of the action
:returns: (new_frame, reward, isOver)
:returns: (reward, isOver)
"""
self.last_act = act
return self._observe()
......@@ -154,7 +151,8 @@ class AtariPlayer(RLEnvironment):
"""
totr = 0
for k in range(self.action_repeat):
s, r, isOver = self.driver.next(self.last_act)
r, isOver = self.driver.action(self.last_act)
s = self.driver.current_state()
totr += r
if isOver:
break
......@@ -174,7 +172,8 @@ class AtariPlayer(RLEnvironment):
return {}
if __name__ == '__main__':
a = AtariDriver('breakout.bin', viz=True)
import sys
a = AtariDriver(sys.argv[1], viz=0.01, height_range=(28,-8))
num = a.get_num_actions()
rng = get_rng(num)
import time
......@@ -182,7 +181,8 @@ if __name__ == '__main__':
#im = a.grab_image()
#cv2.imshow(a.romname, im)
act = rng.choice(range(num))
s, r, o = a.next(act)
time.sleep(0.1)
print act
r, o = a.action(act)
#time.sleep(0.1)
print(r, o)
......@@ -118,7 +118,7 @@ class ParallelPredictWorker(multiprocessing.Process):
os.environ['CUDA_VISIBLE_DEVICES'] = self.gpuid
else:
logger.info("Worker {} uses CPU".format(self.idx))
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['CUDA_VISIBLE_DEVICES'] = ''
G = tf.Graph() # build a graph for each process, because they don't need to share anything
with G.as_default(), tf.device('/gpu:0' if self.gpuid >= 0 else '/cpu:0'):
if self.idx != 0:
......
......@@ -34,6 +34,9 @@ def add_activation_summary(x, name=None):
name = x.name
tf.histogram_summary(name + '/activation', x)
tf.scalar_summary(name + '/activation_sparsity', tf.nn.zero_fraction(x))
tf.scalar_summary(
name + '/activation_rms',
tf.sqrt(tf.reduce_mean(tf.square(x))))
def add_param_summary(summary_lists):
"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment