Commit 3040586d authored by Yuxin Wu's avatar Yuxin Wu

[DQN] fix wrapper shape

parent bf94458d
...@@ -33,7 +33,6 @@ class FrameStack(gym.Wrapper): ...@@ -33,7 +33,6 @@ class FrameStack(gym.Wrapper):
self.frames = deque([], maxlen=k) self.frames = deque([], maxlen=k)
shp = env.observation_space.shape shp = env.observation_space.shape
chan = 1 if len(shp) == 2 else shp[2] chan = 1 if len(shp) == 2 else shp[2]
self._base_dim = len(shp)
self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], chan * k)) self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], chan * k))
def _reset(self): def _reset(self):
...@@ -51,7 +50,7 @@ class FrameStack(gym.Wrapper): ...@@ -51,7 +50,7 @@ class FrameStack(gym.Wrapper):
def _observation(self): def _observation(self):
assert len(self.frames) == self.k assert len(self.frames) == self.k
if self._base_dim == 2: if self.frames[-1].ndim == 2:
return np.stack(self.frames, axis=-1) return np.stack(self.frames, axis=-1)
else: else:
return np.concatenate(self.frames, axis=2) return np.concatenate(self.frames, axis=2)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment