Added StateObservable
This commit is contained in:
parent
6898c2deb5
commit
75732ba960
@ -150,3 +150,73 @@ class RayObservable(Observable):
|
|||||||
col = entity_type(self.env).col
|
col = entity_type(self.env).col
|
||||||
col = int(col[0]/2), int(col[1]/2), int(col[2]/2)
|
col = int(col[0]/2), int(col[1]/2), int(col[2]/2)
|
||||||
pygame.draw.circle(self.env.screen, col, (rx, ry), 3, width=0)
|
pygame.draw.circle(self.env.screen, col, (rx, ry), 3, width=0)
|
||||||
|
|
||||||
|
|
||||||
|
def StateObservable(Observable):
|
||||||
|
def __init__(self, coordsAgent=False, speedAgent=False, coordsRelativeToAgent=True, coordsRewards=True, rewardsWhitelist=None, coordsEnemys=True, enemysWhitelist=None, enemysNoBarriers=True, rewardsTimeouts=True):
|
||||||
|
super(StateObservable, self).__init__()
|
||||||
|
self._entities = None
|
||||||
|
self._timeoutEntities = []
|
||||||
|
self.coordsAgent = coordsAgent
|
||||||
|
self.speedAgent = speedAgent
|
||||||
|
self.coordsRelativeToAgent = coordsRelativeToAgent
|
||||||
|
self.coordRewards = coordsRewards
|
||||||
|
self.rewardsWhitelist = rewardsWhitelist
|
||||||
|
self.coordsEnemys = coordsEnemys
|
||||||
|
self.enemysWhitelist = enemysWhitelist
|
||||||
|
self.enemysNoBarriers = enemysNoBarriers
|
||||||
|
self.rewardsTimeouts = rewardsTimeouts
|
||||||
|
|
||||||
|
@property
|
||||||
|
def entities(self):
|
||||||
|
if self._entities:
|
||||||
|
return self._entities
|
||||||
|
self.rewardsWhitelist = self.rewardsWhitelist or self.env.entities
|
||||||
|
self.enemysWhitelist = self.enemysWhitelist or self.env.entities
|
||||||
|
self._entities = []
|
||||||
|
if self.coordsAgent:
|
||||||
|
self._entities.append(self.env.agent)
|
||||||
|
if self.coordRewards:
|
||||||
|
for entity in self.rewardsWhitelist:
|
||||||
|
if isinstance(entity, entities.Reward):
|
||||||
|
self._entities.append(entity)
|
||||||
|
if self.coordEnemys:
|
||||||
|
for entity in self.enemysWhitelist:
|
||||||
|
if isinstance(entity, entities.Enemy):
|
||||||
|
if not self.enemysNoBarriers or not isinstance(entity, entities.Barrier):
|
||||||
|
self._entities.append(entity)
|
||||||
|
if self.rewardsTimeout:
|
||||||
|
for entity in self.enemysWhitelist:
|
||||||
|
if isinstance(entity, entities.TimeoutReward):
|
||||||
|
self._timeoutEntities.append(entity)
|
||||||
|
return self._entities
|
||||||
|
|
||||||
|
def get_observation_space(self):
|
||||||
|
return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1,
|
||||||
|
shape=(len(self.entities)*2+len(self._timeoutEntities) + self.speedAgent), dtype=np.float32)
|
||||||
|
|
||||||
|
def get_observation(self):
|
||||||
|
obs = []
|
||||||
|
if self.coordsRelativeToAgent:
|
||||||
|
for entity in self.entities:
|
||||||
|
if not isinstance(entity, entities.Agent):
|
||||||
|
obs.append(entity.pos[0] - self.env.agent.pos[0])
|
||||||
|
obs.append(entity.pos[1] - self.env.agent.pos[1])
|
||||||
|
else:
|
||||||
|
obs.append(entity.pos[0])
|
||||||
|
obs.append(entity.pos[1])
|
||||||
|
else:
|
||||||
|
for entity in self.entities:
|
||||||
|
obs.append(entity.pos[0])
|
||||||
|
obs.append(entity.pos[1])
|
||||||
|
|
||||||
|
for entity in self._timeoutEntities:
|
||||||
|
obs.append(entity.active)
|
||||||
|
if self.speedAgent:
|
||||||
|
obs.append(self.env.speed[0])
|
||||||
|
obs.append(self.env.speed[1])
|
||||||
|
self.obs = obs
|
||||||
|
return np.array(obs)
|
||||||
|
|
||||||
|
def draw(self):
|
||||||
|
pass
|
||||||
|
Loading…
Reference in New Issue
Block a user