diff --git a/columbus/observables.py b/columbus/observables.py index 3a17994..6189ce2 100644 --- a/columbus/observables.py +++ b/columbus/observables.py @@ -219,7 +219,7 @@ class StateObservable(Observable): def get_observation_space(self): self.env.reset() num = len(self.entities)*2+len(self._timeoutEntities) + \ - self.speedAgent + self.include_rand + self.speedAgent*2 + self.include_rand return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1, shape=(num,), dtype=np.float32) @@ -241,8 +241,8 @@ class StateObservable(Observable): for entity in self._timeoutEntities: obs.append(entity.active) if self.speedAgent: - obs.append(self.env.speed[0]) - obs.append(self.env.speed[1]) + obs.append(self.env.agent.speed[0]) + obs.append(self.env.agent.speed[1]) if self.include_rand: obs.append(self.env.random()) self.obs = obs