diff --git a/columbus/observables.py b/columbus/observables.py index a903dbb..c6f34e3 100644 --- a/columbus/observables.py +++ b/columbus/observables.py @@ -250,7 +250,7 @@ class StateObservable(Observable): self.reset() num = len(self.entities)*2+len(self._timeoutEntities) + \ self.speedAgent*2 + self.include_rand - return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1, + return spaces.Box(low=0-1*(self.coordsRelativeToAgent or self.speedAgent), high=1, shape=(num,), dtype=np.float64) def get_observation(self):