Added StateObservable
This commit is contained in:
parent
6898c2deb5
commit
75732ba960
@ -150,3 +150,73 @@ class RayObservable(Observable):
|
||||
col = entity_type(self.env).col
|
||||
col = int(col[0]/2), int(col[1]/2), int(col[2]/2)
|
||||
pygame.draw.circle(self.env.screen, col, (rx, ry), 3, width=0)
|
||||
|
||||
|
||||
class StateObservable(Observable):
    """Observable that exposes entity positions as a flat vector.

    The observation is built from, in this order:

    1. an (x, y) pair for every tracked entity (see ``entities``),
    2. the ``active`` flag of every tracked ``entities.TimeoutReward``,
    3. the two components of ``self.env.speed`` (only if ``speedAgent``).

    NOTE(review): the declared bounds are [-1, 1] (relative coordinates) or
    [0, 1] (absolute coordinates); this presumes positions and speeds are
    already normalised by the environment -- confirm against the env.

    Args:
        coordsAgent: Include the agent's own position.
        speedAgent: Append ``env.speed[0]`` and ``env.speed[1]``.
        coordsRelativeToAgent: Report positions of non-agent entities as
            offsets from the agent's position instead of absolute values.
        coordsRewards: Include positions of ``entities.Reward`` instances.
        rewardsWhitelist: Candidate entities for rewards; falls back to
            ``env.entities`` when ``None`` or empty.
        coordsEnemys: Include positions of ``entities.Enemy`` instances.
        enemysWhitelist: Candidate entities for enemies; falls back to
            ``env.entities`` when ``None`` or empty.
        enemysNoBarriers: Skip enemies that are also ``entities.Barrier``.
        rewardsTimeouts: Append the ``active`` flag of every
            ``entities.TimeoutReward`` found in the rewards whitelist.
    """

    def __init__(self, coordsAgent=False, speedAgent=False,
                 coordsRelativeToAgent=True, coordsRewards=True,
                 rewardsWhitelist=None, coordsEnemys=True,
                 enemysWhitelist=None, enemysNoBarriers=True,
                 rewardsTimeouts=True):
        super(StateObservable, self).__init__()
        # Lazily built by the ``entities`` property; ``None`` means
        # "not built yet" (an empty list is a valid, cached result).
        self._entities = None
        self._timeoutEntities = []
        self.coordsAgent = coordsAgent
        self.speedAgent = speedAgent
        self.coordsRelativeToAgent = coordsRelativeToAgent
        # Attribute spelling (without the "s") kept as in the original code
        # so that anything reading ``coordRewards`` keeps working.
        self.coordRewards = coordsRewards
        self.rewardsWhitelist = rewardsWhitelist
        self.coordsEnemys = coordsEnemys
        self.enemysWhitelist = enemysWhitelist
        self.enemysNoBarriers = enemysNoBarriers
        self.rewardsTimeouts = rewardsTimeouts

    @property
    def entities(self):
        """Return the list of entities whose positions are observed.

        The list is built on first access (it needs ``self.env``) and cached
        afterwards. Building it also fills ``self._timeoutEntities``.

        Inside this class' methods the bare name ``entities`` still refers
        to the module-level ``entities`` module, not to this property.
        """
        if self._entities is not None:
            return self._entities

        # An unset (or empty) whitelist means "consider every env entity".
        self.rewardsWhitelist = self.rewardsWhitelist or self.env.entities
        self.enemysWhitelist = self.enemysWhitelist or self.env.entities

        tracked = []
        if self.coordsAgent:
            tracked.append(self.env.agent)
        if self.coordRewards:
            tracked.extend(
                entity for entity in self.rewardsWhitelist
                if isinstance(entity, entities.Reward)
            )
        if self.coordsEnemys:
            for entity in self.enemysWhitelist:
                if not isinstance(entity, entities.Enemy):
                    continue
                if self.enemysNoBarriers and isinstance(entity, entities.Barrier):
                    continue
                tracked.append(entity)

        # Rebuilt from scratch so repeated builds can never accumulate
        # duplicates. Timeout rewards are rewards, so they are looked up in
        # the rewards whitelist.
        timeout_entities = []
        if self.rewardsTimeouts:
            timeout_entities = [
                entity for entity in self.rewardsWhitelist
                if isinstance(entity, entities.TimeoutReward)
            ]

        self._timeoutEntities = timeout_entities
        self._entities = tracked
        return self._entities

    def get_observation_space(self):
        """Return the ``spaces.Box`` matching ``get_observation``.

        The length is two values per tracked entity, one per timeout entity
        and two for the speed when ``speedAgent`` is set.
        """
        # Access the property first: it populates ``_timeoutEntities``.
        tracked = self.entities
        size = 2 * len(tracked) + len(self._timeoutEntities)
        if self.speedAgent:
            size += 2  # get_observation appends speed[0] and speed[1]
        # Relative coordinates can be negative, absolute ones cannot.
        low = -1 if self.coordsRelativeToAgent else 0
        return spaces.Box(low=low, high=1, shape=(size,), dtype=np.float32)

    def get_observation(self):
        """Build the current observation vector.

        The raw Python list is also stored in ``self.obs``.

        Returns:
            A 1-D ``float32`` numpy array laid out as described in the
            class docstring.
        """
        obs = []
        for entity in self.entities:
            x, y = entity.pos[0], entity.pos[1]
            # The agent itself is always reported in absolute coordinates.
            if self.coordsRelativeToAgent and not isinstance(entity, entities.Agent):
                agent_pos = self.env.agent.pos
                x = x - agent_pos[0]
                y = y - agent_pos[1]
            obs.append(x)
            obs.append(y)

        obs.extend(entity.active for entity in self._timeoutEntities)

        if self.speedAgent:
            obs.append(self.env.speed[0])
            obs.append(self.env.speed[1])

        self.obs = obs
        # Match the dtype declared in get_observation_space.
        return np.array(obs, dtype=np.float32)

    def draw(self):
        """Intentionally draws nothing for this observable."""
        pass
|
||||
|
Loading…
Reference in New Issue
Block a user