Commit 3040586d authored by Yuxin Wu's avatar Yuxin Wu

[DQN] fix wrapper shape

parent bf94458d
......@@ -33,7 +33,6 @@ class FrameStack(gym.Wrapper):
self.frames = deque([], maxlen=k)
shp = env.observation_space.shape
chan = 1 if len(shp) == 2 else shp[2]
self._base_dim = len(shp)
self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], chan * k))
def _reset(self):
......@@ -51,7 +50,7 @@ class FrameStack(gym.Wrapper):
def _observation(self):
assert len(self.frames) == self.k
if self._base_dim == 2:
if self.frames[-1].ndim == 2:
return np.stack(self.frames, axis=-1)
else:
return np.concatenate(self.frames, axis=2)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment