Wrong dimensions given by StateObservable
This commit is contained in:
parent
291c9c6320
commit
d4195a3f37
@ -219,7 +219,7 @@ class StateObservable(Observable):
|
|||||||
def get_observation_space(self):
|
def get_observation_space(self):
|
||||||
self.env.reset()
|
self.env.reset()
|
||||||
num = len(self.entities)*2+len(self._timeoutEntities) + \
|
num = len(self.entities)*2+len(self._timeoutEntities) + \
|
||||||
self.speedAgent + self.include_rand
|
self.speedAgent*2 + self.include_rand
|
||||||
return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1,
|
return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1,
|
||||||
shape=(num,), dtype=np.float32)
|
shape=(num,), dtype=np.float32)
|
||||||
|
|
||||||
@ -241,8 +241,8 @@ class StateObservable(Observable):
|
|||||||
for entity in self._timeoutEntities:
|
for entity in self._timeoutEntities:
|
||||||
obs.append(entity.active)
|
obs.append(entity.active)
|
||||||
if self.speedAgent:
|
if self.speedAgent:
|
||||||
obs.append(self.env.speed[0])
|
obs.append(self.env.agent.speed[0])
|
||||||
obs.append(self.env.speed[1])
|
obs.append(self.env.agent.speed[1])
|
||||||
if self.include_rand:
|
if self.include_rand:
|
||||||
obs.append(self.env.random())
|
obs.append(self.env.random())
|
||||||
self.obs = obs
|
self.obs = obs
|
||||||
|
Loading…
Reference in New Issue
Block a user