Commit 0c5e39eb authored by Yuxin Wu's avatar Yuxin Wu

ataridriver as an rlenv

parent d5d7270a
...@@ -37,6 +37,7 @@ NUM_ACTIONS = None ...@@ -37,6 +37,7 @@ NUM_ACTIONS = None
FRAME_HISTORY = 4 FRAME_HISTORY = 4
ACTION_REPEAT = 3 ACTION_REPEAT = 3
HEIGHT_RANGE = (36, 204) # for breakout HEIGHT_RANGE = (36, 204) # for breakout
# HEIGHT_RANGE = (28, -8) # for pong
GAMMA = 0.99 GAMMA = 0.99
BATCH_SIZE = 32 BATCH_SIZE = 32
......
...@@ -2,6 +2,4 @@ termcolor ...@@ -2,6 +2,4 @@ termcolor
pillow pillow
scipy scipy
tqdm tqdm
h5py
nltk
dill dill
...@@ -126,6 +126,7 @@ class ExpReplay(DataFlow, Callback): ...@@ -126,6 +126,7 @@ class ExpReplay(DataFlow, Callback):
#print act, reward #print act, reward
#view_state(s) #view_state(s)
# s is considered useless if isOver==True
self.mem.append(Experience(old_s, act, reward, s, isOver)) self.mem.append(Experience(old_s, act, reward, s, isOver))
def get_data(self): def get_data(self):
......
...@@ -18,12 +18,11 @@ except ImportError: ...@@ -18,12 +18,11 @@ except ImportError:
__all__ = ['AtariDriver', 'AtariPlayer'] __all__ = ['AtariDriver', 'AtariPlayer']
class AtariDriver(object): class AtariDriver(RLEnvironment):
""" """
A wrapper for atari emulator. A wrapper for atari emulator.
""" """
def __init__(self, rom_file, def __init__(self, rom_file, viz=0, height_range=(None,None)):
frame_skip=1, viz=0, height_range=(None,None)):
""" """
:param rom_file: path to the rom :param rom_file: path to the rom
:param frame_skip: skip every k frames :param frame_skip: skip every k frames
...@@ -33,7 +32,8 @@ class AtariDriver(object): ...@@ -33,7 +32,8 @@ class AtariDriver(object):
self.rng = get_rng(self) self.rng = get_rng(self)
self.ale.setInt("random_seed", self.rng.randint(self.rng.randint(0, 1000))) self.ale.setInt("random_seed", self.rng.randint(self.rng.randint(0, 1000)))
self.ale.setInt("frame_skip", frame_skip) self.ale.setInt("frame_skip", 1)
self.ale.setBool('color_averaging', True)
self.ale.loadROM(rom_file) self.ale.loadROM(rom_file)
self.width, self.height = self.ale.getScreenDims() self.width, self.height = self.ale.getScreenDims()
self.actions = self.ale.getMinimalActionSet() self.actions = self.ale.getMinimalActionSet()
...@@ -46,10 +46,10 @@ class AtariDriver(object): ...@@ -46,10 +46,10 @@ class AtariDriver(object):
cv2.startWindowThread() cv2.startWindowThread()
cv2.namedWindow(self.romname) cv2.namedWindow(self.romname)
self._reset()
self.last_image = self._grab_raw_image()
self.framenum = 0
self.height_range = height_range self.height_range = height_range
self.framenum = 0
self._reset()
def _grab_raw_image(self): def _grab_raw_image(self):
""" """
...@@ -59,14 +59,11 @@ class AtariDriver(object): ...@@ -59,14 +59,11 @@ class AtariDriver(object):
self.ale.getScreenRGB(m) self.ale.getScreenRGB(m)
return m.reshape((self.height, self.width, 3)) return m.reshape((self.height, self.width, 3))
def grab_image(self): def current_state(self):
""" """
:returns: a gray-scale image, max-pooled over the last frame. :returns: a gray-scale image, max-pooled over the last frame.
""" """
now = self._grab_raw_image() now = self._grab_raw_image()
ret = np.maximum(now, self.last_image)
self.last_image = now
ret = ret[self.height_range[0]:self.height_range[1],:] # several online repos all use this
if self.viz: if self.viz:
if isinstance(self.viz, float): if isinstance(self.viz, float):
cv2.imshow(self.romname, ret) cv2.imshow(self.romname, ret)
...@@ -74,6 +71,7 @@ class AtariDriver(object): ...@@ -74,6 +71,7 @@ class AtariDriver(object):
else: else:
cv2.imwrite("{}/{:06d}.jpg".format(self.viz, self.framenum), ret) cv2.imwrite("{}/{:06d}.jpg".format(self.viz, self.framenum), ret)
self.framenum += 1 self.framenum += 1
ret = ret[self.height_range[0]:self.height_range[1],:]
ret = cv2.cvtColor(ret, cv2.COLOR_BGR2YUV)[:,:,0] ret = cv2.cvtColor(ret, cv2.COLOR_BGR2YUV)[:,:,0]
return ret return ret
...@@ -86,17 +84,16 @@ class AtariDriver(object): ...@@ -86,17 +84,16 @@ class AtariDriver(object):
def _reset(self): def _reset(self):
self.ale.reset_game() self.ale.reset_game()
def next(self, act): def action(self, act):
""" """
:param act: an index of the action :param act: an index of the action
:returns: (next_image, reward, isOver) :returns: (reward, isOver)
""" """
r = self.ale.act(self.actions[act]) r = self.ale.act(self.actions[act])
s = self.grab_image()
isOver = self.ale.game_over() isOver = self.ale.game_over()
if isOver: if isOver:
self._reset() self._reset()
return (s, r, isOver) return (r, isOver)
class AtariPlayer(RLEnvironment): class AtariPlayer(RLEnvironment):
""" An Atari game player with limited memory and FPS""" """ An Atari game player with limited memory and FPS"""
...@@ -122,7 +119,7 @@ class AtariPlayer(RLEnvironment): ...@@ -122,7 +119,7 @@ class AtariPlayer(RLEnvironment):
""" """
self.current_accum_score = 0 self.current_accum_score = 0
self.frames.clear() self.frames.clear()
s = self.driver.grab_image() s = self.driver.current_state()
s = cv2.resize(s, self.image_shape) s = cv2.resize(s, self.image_shape)
for _ in range(self.hist_len): for _ in range(self.hist_len):
...@@ -138,7 +135,7 @@ class AtariPlayer(RLEnvironment): ...@@ -138,7 +135,7 @@ class AtariPlayer(RLEnvironment):
""" """
Perform an action Perform an action
:param act: index of the action :param act: index of the action
:returns: (new_frame, reward, isOver) :returns: (reward, isOver)
""" """
self.last_act = act self.last_act = act
return self._observe() return self._observe()
...@@ -154,7 +151,8 @@ class AtariPlayer(RLEnvironment): ...@@ -154,7 +151,8 @@ class AtariPlayer(RLEnvironment):
""" """
totr = 0 totr = 0
for k in range(self.action_repeat): for k in range(self.action_repeat):
s, r, isOver = self.driver.next(self.last_act) r, isOver = self.driver.action(self.last_act)
s = self.driver.current_state()
totr += r totr += r
if isOver: if isOver:
break break
...@@ -174,7 +172,8 @@ class AtariPlayer(RLEnvironment): ...@@ -174,7 +172,8 @@ class AtariPlayer(RLEnvironment):
return {} return {}
if __name__ == '__main__': if __name__ == '__main__':
a = AtariDriver('breakout.bin', viz=True) import sys
a = AtariDriver(sys.argv[1], viz=0.01, height_range=(28,-8))
num = a.get_num_actions() num = a.get_num_actions()
rng = get_rng(num) rng = get_rng(num)
import time import time
...@@ -182,7 +181,8 @@ if __name__ == '__main__': ...@@ -182,7 +181,8 @@ if __name__ == '__main__':
#im = a.grab_image() #im = a.grab_image()
#cv2.imshow(a.romname, im) #cv2.imshow(a.romname, im)
act = rng.choice(range(num)) act = rng.choice(range(num))
s, r, o = a.next(act) print act
time.sleep(0.1) r, o = a.action(act)
#time.sleep(0.1)
print(r, o) print(r, o)
...@@ -118,7 +118,7 @@ class ParallelPredictWorker(multiprocessing.Process): ...@@ -118,7 +118,7 @@ class ParallelPredictWorker(multiprocessing.Process):
os.environ['CUDA_VISIBLE_DEVICES'] = self.gpuid os.environ['CUDA_VISIBLE_DEVICES'] = self.gpuid
else: else:
logger.info("Worker {} uses CPU".format(self.idx)) logger.info("Worker {} uses CPU".format(self.idx))
os.environ['CUDA_VISIBLE_DEVICES'] = '0' os.environ['CUDA_VISIBLE_DEVICES'] = ''
G = tf.Graph() # build a graph for each process, because they don't need to share anything G = tf.Graph() # build a graph for each process, because they don't need to share anything
with G.as_default(), tf.device('/gpu:0' if self.gpuid >= 0 else '/cpu:0'): with G.as_default(), tf.device('/gpu:0' if self.gpuid >= 0 else '/cpu:0'):
if self.idx != 0: if self.idx != 0:
......
...@@ -34,6 +34,9 @@ def add_activation_summary(x, name=None): ...@@ -34,6 +34,9 @@ def add_activation_summary(x, name=None):
name = x.name name = x.name
tf.histogram_summary(name + '/activation', x) tf.histogram_summary(name + '/activation', x)
tf.scalar_summary(name + '/activation_sparsity', tf.nn.zero_fraction(x)) tf.scalar_summary(name + '/activation_sparsity', tf.nn.zero_fraction(x))
tf.scalar_summary(
name + '/activation_rms',
tf.sqrt(tf.reduce_mean(tf.square(x))))
def add_param_summary(summary_lists): def add_param_summary(summary_lists):
""" """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment