Wrong dimensions given by StateObservable
This commit is contained in:
parent
291c9c6320
commit
d4195a3f37
@ -219,7 +219,7 @@ class StateObservable(Observable):
|
||||
def get_observation_space(self):
|
||||
self.env.reset()
|
||||
num = len(self.entities)*2+len(self._timeoutEntities) + \
|
||||
self.speedAgent + self.include_rand
|
||||
self.speedAgent*2 + self.include_rand
|
||||
return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1,
|
||||
shape=(num,), dtype=np.float32)
|
||||
|
||||
@ -241,8 +241,8 @@ class StateObservable(Observable):
|
||||
for entity in self._timeoutEntities:
|
||||
obs.append(entity.active)
|
||||
if self.speedAgent:
|
||||
obs.append(self.env.speed[0])
|
||||
obs.append(self.env.speed[1])
|
||||
obs.append(self.env.agent.speed[0])
|
||||
obs.append(self.env.agent.speed[1])
|
||||
if self.include_rand:
|
||||
obs.append(self.env.random())
|
||||
self.obs = obs
|
||||
|
Loading…
Reference in New Issue
Block a user