From f421c92f83c699d9a9394faac38cfa6c5458c605 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Tue, 13 Dec 2022 19:45:13 +0100 Subject: [PATCH] Observation spaces now define dtype to be float64 --- columbus/observables.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/columbus/observables.py b/columbus/observables.py index 4290401..a903dbb 100644 --- a/columbus/observables.py +++ b/columbus/observables.py @@ -16,7 +16,7 @@ class Observable(): def get_observation_space(self): print("[!] Using dummyObservable. Env won't output anything") return spaces.Box(low=0, high=1, - shape=(1,), dtype=np.float32) + shape=(1,), dtype=np.float64) def get_observation(self): return np.array([0]) @@ -45,7 +45,7 @@ class CnnObservable(Observable): def get_observation_space(self): return spaces.Box(low=0, high=255, - shape=(self.out_width, self.out_height, 3), dtype=np.float32) + shape=(self.out_width, self.out_height, 3), dtype=np.float64) def get_observation(self): if not self.env._rendered: @@ -251,7 +251,7 @@ class StateObservable(Observable): num = len(self.entities)*2+len(self._timeoutEntities) + \ self.speedAgent*2 + self.include_rand return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1, - shape=(num,), dtype=np.float32) + shape=(num,), dtype=np.float64) def get_observation(self): obs = [] @@ -331,7 +331,7 @@ class CompassObservable(Observable): self.reset() num = len(self.entities)*2 return spaces.Box(low=-1, high=1, - shape=(num,), dtype=np.float32) + shape=(num,), dtype=np.float64) def reset(self): self._entities = None @@ -385,7 +385,7 @@ class CompositionalObservable(Observable): low = np.hstack((low, space.low.reshape((-1)))) high = np.hstack((high, space.high.reshape((-1)))) return spaces.Box(low=low, high=high, - shape=(num,), dtype=np.float32) + shape=(num,), dtype=np.float64) def get_observation(self): o = [obs.get_observation().reshape((-1))