Fixed bugs with StateObservable

This commit is contained in:
Dominik Moritz Roth 2022-06-21 22:29:50 +02:00
parent 624cefff8d
commit 983d5071a4

View File

@ -186,32 +186,36 @@ class StateObservable(Observable):
@property @property
def entities(self): def entities(self):
if self._entities: if not self._entities == None:
return self._entities return self._entities
self.env.setup() print('Building StateGetters')
self.rewardsWhitelist = self.rewardsWhitelist or self.env.entities rewardsWhitelist = self.rewardsWhitelist or self.env.entities
self.enemysWhitelist = self.enemysWhitelist or self.env.entities enemysWhitelist = self.enemysWhitelist or self.env.entities
self._entities = [] self._entities = []
if self.coordsAgent: if self.coordsAgent:
self._entities.append(self.env.agent) self._entities.append(self.env.agent)
if self.coordRewards: if self.coordRewards:
for entity in self.rewardsWhitelist: for entity in rewardsWhitelist:
if isinstance(entity, entities.Reward): if isinstance(entity, entities.Reward):
self._entities.append(entity) self._entities.append(entity)
if self.coordsEnemys: if self.coordsEnemys:
for entity in self.enemysWhitelist: for entity in enemysWhitelist:
if isinstance(entity, entities.Enemy): if isinstance(entity, entities.Enemy):
if not self.enemysNoBarriers or not isinstance(entity, entities.Barrier): if not self.enemysNoBarriers or not isinstance(entity, entities.Barrier):
self._entities.append(entity) self._entities.append(entity)
if self.rewardsTimeouts: if self.rewardsTimeouts:
for entity in self.enemysWhitelist: for entity in enemysWhitelist:
if isinstance(entity, entities.TimeoutReward): if isinstance(entity, entities.TimeoutReward):
self._timeoutEntities.append(entity) self._timeoutEntities.append(entity)
print(len(self._entities))
return self._entities return self._entities
def get_observation_space(self): def get_observation_space(self):
self.env.setup()
num = len(self.entities)*2+len(self._timeoutEntities) + \
self.speedAgent + self.include_rand
return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1, return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1,
shape=(len(self.entities)*2+len(self._timeoutEntities) + self.speedAgent + self.include_rand,), dtype=np.float32) shape=(num,), dtype=np.float32)
def get_observation(self): def get_observation(self):
obs = [] obs = []
@ -239,4 +243,10 @@ class StateObservable(Observable):
return np.array(obs) return np.array(obs)
def draw(self): def draw(self):
pass for i in range(int(len(self.obs)/2)):
x, y = self.obs[i*2], self.obs[i*2+1]
col = self.entities[i].col
pygame.draw.circle(self.env.screen, col,
(0, y*self.env.height), 1, width=0)
pygame.draw.circle(self.env.screen, col,
(x*self.env.width, 0), 1, width=0)