Commit 0b2d375d authored by Yuxin Wu

use uint8 image in ALE

parent 76128e42
...@@ -40,6 +40,8 @@ EXPLORATION_EPOCH_ANNEAL = 0.01 ...@@ -40,6 +40,8 @@ EXPLORATION_EPOCH_ANNEAL = 0.01
END_EXPLORATION = 0.1 END_EXPLORATION = 0.1
MEMORY_SIZE = 1e6 MEMORY_SIZE = 1e6
# NOTE: will consume at least 1e6 * 84 * 84 * 4 bytes ~= 28 GB (26 GiB) of memory.
# Suggest using tcmalloc to manage memory space better.
INIT_MEMORY_SIZE = 5e4 INIT_MEMORY_SIZE = 5e4
STEP_PER_EPOCH = 10000 STEP_PER_EPOCH = 10000
EVAL_EPISODE = 50 EVAL_EPISODE = 50
......
...@@ -54,7 +54,7 @@ class AtariPlayer(RLEnvironment): ...@@ -54,7 +54,7 @@ class AtariPlayer(RLEnvironment):
ALEInterface.setLoggerMode(ALEInterface.Logger.Warning) ALEInterface.setLoggerMode(ALEInterface.Logger.Warning)
except AttributeError: except AttributeError:
if execute_only_once(): if execute_only_once():
logger.warn("https://github.com/mgbellemare/Arcade-Learning-Environment/pull/171 is not merged!") logger.warn("You're not using latest ALE")
# avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86 # avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86
with _ALE_LOCK: with _ALE_LOCK:
...@@ -104,14 +104,13 @@ class AtariPlayer(RLEnvironment): ...@@ -104,14 +104,13 @@ class AtariPlayer(RLEnvironment):
def current_state(self): def current_state(self):
""" """
:returns: a gray-scale (h, w, 1) float32 image :returns: a gray-scale (h, w, 1) uint8 image
""" """
ret = self._grab_raw_image() ret = self._grab_raw_image()
# max-pooled over the last screen # max-pooled over the last screen
ret = np.maximum(ret, self.last_raw_screen) ret = np.maximum(ret, self.last_raw_screen)
if self.viz: if self.viz:
if isinstance(self.viz, float): if isinstance(self.viz, float):
#m = cv2.resize(ret, (1920,1200))
cv2.imshow(self.windowname, ret) cv2.imshow(self.windowname, ret)
time.sleep(self.viz) time.sleep(self.viz)
ret = ret[self.height_range[0]:self.height_range[1],:].astype('float32') ret = ret[self.height_range[0]:self.height_range[1],:].astype('float32')
...@@ -119,7 +118,7 @@ class AtariPlayer(RLEnvironment): ...@@ -119,7 +118,7 @@ class AtariPlayer(RLEnvironment):
ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY) ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
ret = cv2.resize(ret, self.image_shape) ret = cv2.resize(ret, self.image_shape)
ret = np.expand_dims(ret, axis=2) ret = np.expand_dims(ret, axis=2)
return ret return ret.astype('uint8') # to save some memory
def get_action_space(self): def get_action_space(self):
return DiscreteActionSpace(len(self.actions)) return DiscreteActionSpace(len(self.actions))
......
...@@ -78,14 +78,16 @@ class ExpReplay(DataFlow, Callback): ...@@ -78,14 +78,16 @@ class ExpReplay(DataFlow, Callback):
with tqdm(total=self.init_memory_size) as pbar: with tqdm(total=self.init_memory_size) as pbar:
while len(self.mem) < self.init_memory_size: while len(self.mem) < self.init_memory_size:
#from copy import deepcopy # quickly fill the memory for debug
#self.mem.append(deepcopy(self.mem[0]))
self._populate_exp() self._populate_exp()
pbar.update() pbar.update()
self._init_memory_flag.set() self._init_memory_flag.set()
def _populate_exp(self): def _populate_exp(self):
""" populate a transition by epsilon-greedy""" """ populate a transition by epsilon-greedy"""
#if len(self.mem):
#from copy import deepcopy # quickly fill the memory for debug
#self.mem.append(deepcopy(self.mem[0]))
#return
old_s = self.player.current_state() old_s = self.player.current_state()
if self.rng.rand() <= self.exploration: if self.rng.rand() <= self.exploration:
act = self.rng.choice(range(self.num_actions)) act = self.rng.choice(range(self.num_actions))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment