Observation spaces now define dtype to be float64

This commit is contained in:
Dominik Moritz Roth 2022-12-13 19:45:13 +01:00
parent 164c72504c
commit f421c92f83

View File

@ -16,7 +16,7 @@ class Observable():
def get_observation_space(self): def get_observation_space(self):
print("[!] Using dummyObservable. Env won't output anything") print("[!] Using dummyObservable. Env won't output anything")
return spaces.Box(low=0, high=1, return spaces.Box(low=0, high=1,
shape=(1,), dtype=np.float32) shape=(1,), dtype=np.float64)
def get_observation(self): def get_observation(self):
return np.array([0]) return np.array([0])
@ -45,7 +45,7 @@ class CnnObservable(Observable):
def get_observation_space(self): def get_observation_space(self):
return spaces.Box(low=0, high=255, return spaces.Box(low=0, high=255,
shape=(self.out_width, self.out_height, 3), dtype=np.float32) shape=(self.out_width, self.out_height, 3), dtype=np.float64)
def get_observation(self): def get_observation(self):
if not self.env._rendered: if not self.env._rendered:
@ -251,7 +251,7 @@ class StateObservable(Observable):
num = len(self.entities)*2+len(self._timeoutEntities) + \ num = len(self.entities)*2+len(self._timeoutEntities) + \
self.speedAgent*2 + self.include_rand self.speedAgent*2 + self.include_rand
return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1, return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1,
shape=(num,), dtype=np.float32) shape=(num,), dtype=np.float64)
def get_observation(self): def get_observation(self):
obs = [] obs = []
@ -331,7 +331,7 @@ class CompassObservable(Observable):
self.reset() self.reset()
num = len(self.entities)*2 num = len(self.entities)*2
return spaces.Box(low=-1, high=1, return spaces.Box(low=-1, high=1,
shape=(num,), dtype=np.float32) shape=(num,), dtype=np.float64)
def reset(self): def reset(self):
self._entities = None self._entities = None
@ -385,7 +385,7 @@ class CompositionalObservable(Observable):
low = np.hstack((low, space.low.reshape((-1)))) low = np.hstack((low, space.low.reshape((-1))))
high = np.hstack((high, space.high.reshape((-1)))) high = np.hstack((high, space.high.reshape((-1))))
return spaces.Box(low=low, high=high, return spaces.Box(low=low, high=high,
shape=(num,), dtype=np.float32) shape=(num,), dtype=np.float64)
def get_observation(self): def get_observation(self):
o = [obs.get_observation().reshape((-1)) o = [obs.get_observation().reshape((-1))