Commit 0b2d375d authored by Yuxin Wu

use uint8 image in ALE

parent 76128e42
@@ -40,6 +40,8 @@ EXPLORATION_EPOCH_ANNEAL = 0.01
 END_EXPLORATION = 0.1
 MEMORY_SIZE = 1e6
+# NOTE: will consume at least 1e6 * 84 * 84 * 4 bytes = 26G memory.
+# Suggest using tcmalloc to manage memory space better.
 INIT_MEMORY_SIZE = 5e4
 STEP_PER_EPOCH = 10000
 EVAL_EPISODE = 50
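The NOTE added above is plain arithmetic on the replay-memory shape. A minimal sketch of the numbers, assuming each of the 1e6 entries keeps an (84, 84, 4) state stored as uint8 (one byte per pixel, history length 4):

```python
MEMORY_SIZE = int(1e6)
state_bytes = 84 * 84 * 4            # (h, w, history) at 1 byte per pixel (uint8)
total = MEMORY_SIZE * state_bytes    # 28,224,000,000 bytes
print(total / 1024 ** 3)             # ~26.3 -> the "26G" in the comment
print(total * 4 / 1024 ** 3)         # ~105 GiB if states stayed float32
```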
@@ -54,7 +54,7 @@ class AtariPlayer(RLEnvironment):
             ALEInterface.setLoggerMode(ALEInterface.Logger.Warning)
         except AttributeError:
             if execute_only_once():
-                logger.warn("https://github.com/mgbellemare/Arcade-Learning-Environment/pull/171 is not merged!")
+                logger.warn("You're not using latest ALE")

         # avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86
         with _ALE_LOCK:
@@ -104,14 +104,13 @@ class AtariPlayer(RLEnvironment):

     def current_state(self):
         """
-        :returns: a gray-scale (h, w, 1) float32 image
+        :returns: a gray-scale (h, w, 1) uint8 image
         """
         ret = self._grab_raw_image()
         # max-pooled over the last screen
         ret = np.maximum(ret, self.last_raw_screen)
         if self.viz:
             if isinstance(self.viz, float):
-                #m = cv2.resize(ret, (1920,1200))
                 cv2.imshow(self.windowname, ret)
                 time.sleep(self.viz)
         ret = ret[self.height_range[0]:self.height_range[1],:].astype('float32')
@@ -119,7 +118,7 @@ class AtariPlayer(RLEnvironment):
         ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
         ret = cv2.resize(ret, self.image_shape)
         ret = np.expand_dims(ret, axis=2)
-        return ret
+        return ret.astype('uint8')  # to save some memory

     def get_action_space(self):
         return DiscreteActionSpace(len(self.actions))
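The two hunks above carry the point of the commit: keep frames as uint8 inside the player and the replay memory, and pay the float32 cost only for the minibatch actually fed to the network. A condensed sketch of that pattern, with illustrative function and argument names that are not the tensorpack API:

```python
import cv2
import numpy as np

def current_state_uint8(screen, last_screen, height_range, image_shape):
    """Condensed sketch of AtariPlayer.current_state() after this commit."""
    ret = np.maximum(screen, last_screen)      # max over 2 raw screens: removes flicker
    ret = ret[height_range[0]:height_range[1], :].astype('float32')
    ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
    ret = cv2.resize(ret, image_shape)
    ret = np.expand_dims(ret, axis=2)
    return ret.astype('uint8')                 # store 1 byte/pixel instead of 4

# at training time, cast back only for the batch going into the network:
# batch = np.stack(states).astype('float32')
```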
@@ -78,14 +78,16 @@ class ExpReplay(DataFlow, Callback):
         with tqdm(total=self.init_memory_size) as pbar:
             while len(self.mem) < self.init_memory_size:
-                #from copy import deepcopy # quickly fill the memory for debug
-                #self.mem.append(deepcopy(self.mem[0]))
                 self._populate_exp()
                 pbar.update()
         self._init_memory_flag.set()

     def _populate_exp(self):
         """ populate a transition by epsilon-greedy"""
+        #if len(self.mem):
+            #from copy import deepcopy # quickly fill the memory for debug
+            #self.mem.append(deepcopy(self.mem[0]))
+            #return
         old_s = self.player.current_state()
         if self.rng.rand() <= self.exploration:
             act = self.rng.choice(range(self.num_actions))
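For context, the epsilon-greedy step that the last hunk truncates continues roughly as below. This is a sketch, assuming `predictor` is any callable mapping a state to Q-values and that `player.action()` returns `(reward, isOver)` as in tensorpack's RLEnvironment; none of these names are taken from the diff itself:

```python
import numpy as np
from collections import namedtuple

Experience = namedtuple('Experience', ['state', 'action', 'reward', 'isOver'])

def populate_exp(player, predictor, mem, rng, exploration, num_actions):
    """Sketch of one epsilon-greedy transition, mirroring _populate_exp."""
    old_s = player.current_state()              # a uint8 image after this commit
    if rng.rand() <= exploration:
        act = rng.choice(range(num_actions))    # explore: uniform random action
    else:
        q_values = predictor(old_s)             # exploit: greedy on Q-values
        act = int(np.argmax(q_values))
    reward, isOver = player.action(act)
    mem.append(Experience(old_s, act, reward, isOver))
```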