Fixed bugs with StateObservable
This commit is contained in:
parent
624cefff8d
commit
983d5071a4
@ -186,32 +186,36 @@ class StateObservable(Observable):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def entities(self):
|
def entities(self):
|
||||||
if self._entities:
|
if not self._entities == None:
|
||||||
return self._entities
|
return self._entities
|
||||||
self.env.setup()
|
print('Building StateGetters')
|
||||||
self.rewardsWhitelist = self.rewardsWhitelist or self.env.entities
|
rewardsWhitelist = self.rewardsWhitelist or self.env.entities
|
||||||
self.enemysWhitelist = self.enemysWhitelist or self.env.entities
|
enemysWhitelist = self.enemysWhitelist or self.env.entities
|
||||||
self._entities = []
|
self._entities = []
|
||||||
if self.coordsAgent:
|
if self.coordsAgent:
|
||||||
self._entities.append(self.env.agent)
|
self._entities.append(self.env.agent)
|
||||||
if self.coordRewards:
|
if self.coordRewards:
|
||||||
for entity in self.rewardsWhitelist:
|
for entity in rewardsWhitelist:
|
||||||
if isinstance(entity, entities.Reward):
|
if isinstance(entity, entities.Reward):
|
||||||
self._entities.append(entity)
|
self._entities.append(entity)
|
||||||
if self.coordsEnemys:
|
if self.coordsEnemys:
|
||||||
for entity in self.enemysWhitelist:
|
for entity in enemysWhitelist:
|
||||||
if isinstance(entity, entities.Enemy):
|
if isinstance(entity, entities.Enemy):
|
||||||
if not self.enemysNoBarriers or not isinstance(entity, entities.Barrier):
|
if not self.enemysNoBarriers or not isinstance(entity, entities.Barrier):
|
||||||
self._entities.append(entity)
|
self._entities.append(entity)
|
||||||
if self.rewardsTimeouts:
|
if self.rewardsTimeouts:
|
||||||
for entity in self.enemysWhitelist:
|
for entity in enemysWhitelist:
|
||||||
if isinstance(entity, entities.TimeoutReward):
|
if isinstance(entity, entities.TimeoutReward):
|
||||||
self._timeoutEntities.append(entity)
|
self._timeoutEntities.append(entity)
|
||||||
|
print(len(self._entities))
|
||||||
return self._entities
|
return self._entities
|
||||||
|
|
||||||
def get_observation_space(self):
|
def get_observation_space(self):
|
||||||
|
self.env.setup()
|
||||||
|
num = len(self.entities)*2+len(self._timeoutEntities) + \
|
||||||
|
self.speedAgent + self.include_rand
|
||||||
return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1,
|
return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1,
|
||||||
shape=(len(self.entities)*2+len(self._timeoutEntities) + self.speedAgent + self.include_rand,), dtype=np.float32)
|
shape=(num,), dtype=np.float32)
|
||||||
|
|
||||||
def get_observation(self):
|
def get_observation(self):
|
||||||
obs = []
|
obs = []
|
||||||
@ -239,4 +243,10 @@ class StateObservable(Observable):
|
|||||||
return np.array(obs)
|
return np.array(obs)
|
||||||
|
|
||||||
def draw(self):
|
def draw(self):
|
||||||
pass
|
for i in range(int(len(self.obs)/2)):
|
||||||
|
x, y = self.obs[i*2], self.obs[i*2+1]
|
||||||
|
col = self.entities[i].col
|
||||||
|
pygame.draw.circle(self.env.screen, col,
|
||||||
|
(0, y*self.env.height), 1, width=0)
|
||||||
|
pygame.draw.circle(self.env.screen, col,
|
||||||
|
(x*self.env.width, 0), 1, width=0)
|
||||||
|
Loading…
Reference in New Issue
Block a user