Added StateObservable

This commit is contained in:
Dominik Moritz Roth 2022-06-19 22:46:42 +02:00
parent 6898c2deb5
commit 75732ba960

View File

@@ -150,3 +150,73 @@ class RayObservable(Observable):
col = entity_type(self.env).col col = entity_type(self.env).col
col = int(col[0]/2), int(col[1]/2), int(col[2]/2) col = int(col[0]/2), int(col[1]/2), int(col[2]/2)
pygame.draw.circle(self.env.screen, col, (rx, ry), 3, width=0) pygame.draw.circle(self.env.screen, col, (rx, ry), 3, width=0)
class StateObservable(Observable):
    """Observable that exposes entity state directly as a flat vector.

    The observation is laid out as:

    1. an ``(x, y)`` pair for every tracked entity (optionally the agent,
       then rewards, then enemies),
    2. one ``active`` flag per tracked ``TimeoutReward``,
    3. the two speed components, when ``speedAgent`` is enabled.

    Args:
        coordsAgent: Include the agent's own position.
        speedAgent: Append the two speed components read from
            ``env.speed``.
        coordsRelativeToAgent: Report non-agent positions as offsets from
            the agent's position instead of absolute coordinates.
        coordsRewards: Include positions of ``entities.Reward`` instances.
        rewardsWhitelist: Candidate entities for rewards; ``None`` (or an
            empty sequence) falls back to ``env.entities``.
        coordsEnemys: Include positions of ``entities.Enemy`` instances.
        enemysWhitelist: Candidate entities for enemies; ``None`` (or an
            empty sequence) falls back to ``env.entities``.
        enemysNoBarriers: Skip enemies that are ``entities.Barrier``.
        rewardsTimeouts: Append the ``active`` flag of every
            ``entities.TimeoutReward``.
    """

    def __init__(self, coordsAgent=False, speedAgent=False, coordsRelativeToAgent=True, coordsRewards=True, rewardsWhitelist=None, coordsEnemys=True, enemysWhitelist=None, enemysNoBarriers=True, rewardsTimeouts=True):
        super(StateObservable, self).__init__()
        # Both lists are filled lazily by the `entities` property because
        # they depend on `self.env`, which is not available here.
        # NOTE(review): assumes `self.env` is attached to the observable
        # before `entities` is first read -- confirm against Observable.
        self._entities = None
        self._timeoutEntities = []
        self.coordsAgent = coordsAgent
        self.speedAgent = speedAgent
        self.coordsRelativeToAgent = coordsRelativeToAgent
        self.coordsRewards = coordsRewards
        self.rewardsWhitelist = rewardsWhitelist
        self.coordsEnemys = coordsEnemys
        self.enemysWhitelist = enemysWhitelist
        self.enemysNoBarriers = enemysNoBarriers
        self.rewardsTimeouts = rewardsTimeouts

    @property
    def entities(self):
        """Return the tracked entities, resolving them on first access.

        Resolution also fills ``self._timeoutEntities`` and replaces unset
        whitelists with ``env.entities``.
        """
        # Compare against None rather than truthiness: an empty result is
        # still a valid cache and must not trigger a rebuild, which would
        # append duplicates to `_timeoutEntities`.
        if self._entities is not None:
            return self._entities

        self.rewardsWhitelist = self.rewardsWhitelist or self.env.entities
        self.enemysWhitelist = self.enemysWhitelist or self.env.entities

        tracked = []
        timeouts = []
        if self.coordsAgent:
            tracked.append(self.env.agent)
        if self.coordsRewards:
            tracked.extend(
                entity for entity in self.rewardsWhitelist
                if isinstance(entity, entities.Reward)
            )
        if self.coordsEnemys:
            for entity in self.enemysWhitelist:
                if not isinstance(entity, entities.Enemy):
                    continue
                if self.enemysNoBarriers and isinstance(entity, entities.Barrier):
                    continue
                tracked.append(entity)
        if self.rewardsTimeouts:
            # Timeout rewards are drawn from the rewards whitelist.
            timeouts.extend(
                entity for entity in self.rewardsWhitelist
                if isinstance(entity, entities.TimeoutReward)
            )

        self._timeoutEntities = timeouts
        self._entities = tracked
        return self._entities

    def get_observation_space(self):
        """Return the ``spaces.Box`` matching ``get_observation``'s layout."""
        # Read `entities` first so `_timeoutEntities` is populated.
        tracked_count = len(self.entities)
        # Two coordinates per entity, one flag per timeout reward and two
        # speed components when speed reporting is enabled.
        size = (
            2 * tracked_count
            + len(self._timeoutEntities)
            + (2 if self.speedAgent else 0)
        )
        # Relative offsets can be negative, absolute coordinates cannot.
        # NOTE(review): bounds assume positions are normalised to [0, 1]
        # -- confirm against the environment.
        lower = -1 if self.coordsRelativeToAgent else 0
        return spaces.Box(low=lower, high=1, shape=(size,), dtype=np.float32)

    def get_observation(self):
        """Build the observation vector for the current environment state.

        Returns:
            A 1-D ``float32`` array; the raw list of values is also kept in
            ``self.obs``.
        """
        obs = []
        agent = self.env.agent
        for entity in self.entities:
            x, y = entity.pos[0], entity.pos[1]
            # The agent itself is always reported in absolute coordinates.
            if self.coordsRelativeToAgent and not isinstance(entity, entities.Agent):
                x = x - agent.pos[0]
                y = y - agent.pos[1]
            obs.append(x)
            obs.append(y)
        for entity in self._timeoutEntities:
            obs.append(entity.active)
        if self.speedAgent:
            obs.append(self.env.speed[0])
            obs.append(self.env.speed[1])
        self.obs = obs
        # Match the dtype declared in get_observation_space.
        return np.array(obs, dtype=np.float32)

    def draw(self):
        """Draw nothing; this observable has no visual representation."""
        pass