diff --git a/columbus/observables.py b/columbus/observables.py index 07e3b39..3187a98 100644 --- a/columbus/observables.py +++ b/columbus/observables.py @@ -301,11 +301,14 @@ class CompassObservable(Observable): return self._entities def get_observation_space(self): - self.env.reset() + self.reset() num = len(self.entities)*2 return spaces.Box(low=-1, high=1, shape=(num,), dtype=np.float32) + def reset(self): + self._entities = None + def get_observation(self): obs = [] for entity in self.entities: