Extended output of StateObservable

This commit is contained in:
Dominik Moritz Roth 2022-06-29 12:42:49 +02:00
parent d998d816a1
commit 29854b2b5c

View File

@ -209,7 +209,7 @@ class StateObservable(Observable):
return self._entities return self._entities
def get_observation_space(self): def get_observation_space(self):
self.env.setup() self.env.reset()
num = len(self.entities)*2+len(self._timeoutEntities) + \ num = len(self.entities)*2+len(self._timeoutEntities) + \
self.speedAgent + self.include_rand self.speedAgent + self.include_rand
return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1, return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1,
@ -241,10 +241,17 @@ class StateObservable(Observable):
return np.array(obs) return np.array(obs)
def draw(self): def draw(self):
ofs = (0 + self.env.height/2*self.coordsRelativeToAgent,
0 + self.env.width/2*self.coordsRelativeToAgent)
if self.coordsRelativeToAgent:
pygame.draw.circle(self.env.screen, self.env.agent.col,
(0, self.env.height/2), 3, width=0)
pygame.draw.circle(self.env.screen, self.env.agent.col,
(self.env.width/2, 0), 3, width=0)
for i in range(int(len(self.obs)/2)): for i in range(int(len(self.obs)/2)):
x, y = self.obs[i*2], self.obs[i*2+1] x, y = self.obs[i*2], self.obs[i*2+1]
col = self.entities[i].col col = self.entities[i].col
pygame.draw.circle(self.env.screen, col, pygame.draw.circle(self.env.screen, col,
(0, y*self.env.height), 1, width=0) (0, y*self.env.height+ofs[0]), 1, width=0)
pygame.draw.circle(self.env.screen, col, pygame.draw.circle(self.env.screen, col,
(x*self.env.width, 0), 1, width=0) (x*self.env.width+ofs[1], 0), 1, width=0)