Wrong dimensions given by StateObservable

This commit is contained in:
Dominik Moritz Roth 2022-08-22 18:08:01 +02:00
parent 291c9c6320
commit d4195a3f37

View File

@ -219,7 +219,7 @@ class StateObservable(Observable):
def get_observation_space(self): def get_observation_space(self):
self.env.reset() self.env.reset()
num = len(self.entities)*2+len(self._timeoutEntities) + \ num = len(self.entities)*2+len(self._timeoutEntities) + \
self.speedAgent + self.include_rand self.speedAgent*2 + self.include_rand
return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1, return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1,
shape=(num,), dtype=np.float32) shape=(num,), dtype=np.float32)
@ -241,8 +241,8 @@ class StateObservable(Observable):
for entity in self._timeoutEntities: for entity in self._timeoutEntities:
obs.append(entity.active) obs.append(entity.active)
if self.speedAgent: if self.speedAgent:
obs.append(self.env.speed[0]) obs.append(self.env.agent.speed[0])
obs.append(self.env.speed[1]) obs.append(self.env.agent.speed[1])
if self.include_rand: if self.include_rand:
obs.append(self.env.random()) obs.append(self.env.random())
self.obs = obs self.obs = obs