Observation spaces now define dtype to be float64
This commit is contained in:
parent
164c72504c
commit
f421c92f83
@ -16,7 +16,7 @@ class Observable():
|
||||
def get_observation_space(self):
|
||||
print("[!] Using dummyObservable. Env won't output anything")
|
||||
return spaces.Box(low=0, high=1,
|
||||
shape=(1,), dtype=np.float32)
|
||||
shape=(1,), dtype=np.float64)
|
||||
|
||||
def get_observation(self):
|
||||
return np.array([0])
|
||||
@ -45,7 +45,7 @@ class CnnObservable(Observable):
|
||||
|
||||
def get_observation_space(self):
|
||||
return spaces.Box(low=0, high=255,
|
||||
shape=(self.out_width, self.out_height, 3), dtype=np.float32)
|
||||
shape=(self.out_width, self.out_height, 3), dtype=np.float64)
|
||||
|
||||
def get_observation(self):
|
||||
if not self.env._rendered:
|
||||
@ -251,7 +251,7 @@ class StateObservable(Observable):
|
||||
num = len(self.entities)*2+len(self._timeoutEntities) + \
|
||||
self.speedAgent*2 + self.include_rand
|
||||
return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1,
|
||||
shape=(num,), dtype=np.float32)
|
||||
shape=(num,), dtype=np.float64)
|
||||
|
||||
def get_observation(self):
|
||||
obs = []
|
||||
@ -331,7 +331,7 @@ class CompassObservable(Observable):
|
||||
self.reset()
|
||||
num = len(self.entities)*2
|
||||
return spaces.Box(low=-1, high=1,
|
||||
shape=(num,), dtype=np.float32)
|
||||
shape=(num,), dtype=np.float64)
|
||||
|
||||
def reset(self):
|
||||
self._entities = None
|
||||
@ -385,7 +385,7 @@ class CompositionalObservable(Observable):
|
||||
low = np.hstack((low, space.low.reshape((-1))))
|
||||
high = np.hstack((high, space.high.reshape((-1))))
|
||||
return spaces.Box(low=low, high=high,
|
||||
shape=(num,), dtype=np.float32)
|
||||
shape=(num,), dtype=np.float64)
|
||||
|
||||
def get_observation(self):
|
||||
o = [obs.get_observation().reshape((-1))
|
||||
|
Loading…
Reference in New Issue
Block a user