diff --git a/columbus/observables.py b/columbus/observables.py index 49d0395..a61d66a 100644 --- a/columbus/observables.py +++ b/columbus/observables.py @@ -150,3 +150,73 @@ class RayObservable(Observable): col = entity_type(self.env).col col = int(col[0]/2), int(col[1]/2), int(col[2]/2) pygame.draw.circle(self.env.screen, col, (rx, ry), 3, width=0) + + +def StateObservable(Observable): + def __init__(self, coordsAgent=False, speedAgent=False, coordsRelativeToAgent=True, coordsRewards=True, rewardsWhitelist=None, coordsEnemys=True, enemysWhitelist=None, enemysNoBarriers=True, rewardsTimeouts=True): + super(StateObservable, self).__init__() + self._entities = None + self._timeoutEntities = [] + self.coordsAgent = coordsAgent + self.speedAgent = speedAgent + self.coordsRelativeToAgent = coordsRelativeToAgent + self.coordRewards = coordsRewards + self.rewardsWhitelist = rewardsWhitelist + self.coordsEnemys = coordsEnemys + self.enemysWhitelist = enemysWhitelist + self.enemysNoBarriers = enemysNoBarriers + self.rewardsTimeouts = rewardsTimeouts + + @property + def entities(self): + if self._entities: + return self._entities + self.rewardsWhitelist = self.rewardsWhitelist or self.env.entities + self.enemysWhitelist = self.enemysWhitelist or self.env.entities + self._entities = [] + if self.coordsAgent: + self._entities.append(self.env.agent) + if self.coordRewards: + for entity in self.rewardsWhitelist: + if isinstance(entity, entities.Reward): + self._entities.append(entity) + if self.coordEnemys: + for entity in self.enemysWhitelist: + if isinstance(entity, entities.Enemy): + if not self.enemysNoBarriers or not isinstance(entity, entities.Barrier): + self._entities.append(entity) + if self.rewardsTimeout: + for entity in self.enemysWhitelist: + if isinstance(entity, entities.TimeoutReward): + self._timeoutEntities.append(entity) + return self._entities + + def get_observation_space(self): + return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1, + shape=(len(self.entities)*2+len(self._timeoutEntities) + self.speedAgent), dtype=np.float32) + + def get_observation(self): + obs = [] + if self.coordsRelativeToAgent: + for entity in self.entities: + if not isinstance(entity, entities.Agent): + obs.append(entity.pos[0] - self.env.agent.pos[0]) + obs.append(entity.pos[1] - self.env.agent.pos[1]) + else: + obs.append(entity.pos[0]) + obs.append(entity.pos[1]) + else: + for entity in self.entities: + obs.append(entity.pos[0]) + obs.append(entity.pos[1]) + + for entity in self._timeoutEntities: + obs.append(entity.active) + if self.speedAgent: + obs.append(self.env.speed[0]) + obs.append(self.env.speed[1]) + self.obs = obs + return np.array(obs) + + def draw(self): + pass