From 4bfa15b3625ff58d1c7d3f62e41755b703fcfa3a Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Mon, 20 Jun 2022 23:11:11 +0200 Subject: [PATCH] Bug fixes and minor additions to observables --- columbus/observables.py | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/columbus/observables.py b/columbus/observables.py index 90b652d..cbd4547 100644 --- a/columbus/observables.py +++ b/columbus/observables.py @@ -57,6 +57,10 @@ class CnnObservable(Observable): rect = pygame.Rect(cx, cy, cw, ch) snap = self.env.surface.subsurface(rect) self.snap = pygame.Surface((self.in_width, self.in_height)) + if self.env.void_barrier: + col = (223, 0, 0) + else: + col = (50, 50, 50) pygame.draw.rect(self.snap, (50, 50, 50), pygame.Rect(0, 0, self.in_width, self.in_height)) self.snap.blit(snap, (cx - x, cy - y)) @@ -82,18 +86,19 @@ def _clip(num, lower, upper): class RayObservable(Observable): - def __init__(self, num_rays=24, chans=[entities.Enemy, entities.Reward, entities.Void], ray_len=256): + def __init__(self, num_rays=16, chans=[entities.Enemy, entities.Reward], ray_len=256, num_steps=64, include_rand=False): super(RayObservable, self).__init__() self.num_rays = num_rays self.chans = chans self.num_chans = len(chans) self.ray_len = ray_len - self.num_steps = 32 # max = 255 + self.num_steps = num_steps # max = 255 self.occlusion = True # previous channels block view onto later channels + self.include_rand = include_rand def get_observation_space(self): return spaces.Box(low=0, high=self.num_steps, - shape=(self.num_rays, self.num_chans), dtype=np.uint8) + shape=(self.num_rays+self.include_rand, self.num_chans), dtype=np.uint8) def _get_ray_heads(self): for i in range(self.num_rays): @@ -102,12 +107,10 @@ class RayObservable(Observable): def _check_collision(self, pos, entity_type, entities_l): for entity in entities_l: - if isinstance(entity, entity_type): + if isinstance(entity, entity_type) or (self.env.void_barrier and isinstance(entity, entities.Void) and entity_type == entities.Enemy): if isinstance(entity, entities.Void): - hit = 0 >= pos[0] or pos[0] >= self.env.width or 0 >= pos[1] or pos[0] >= self.env.height - if hit: - print(pos) - return hit + if 0 >= pos[0] or pos[0] >= self.env.width or 0 >= pos[1] or pos[0] >= self.env.height: + return True else: if entity.shape != 'circle': raise Exception('Can only raycast circular entities!') @@ -119,7 +122,7 @@ class RayObservable(Observable): def _get_possible_entities(self): entities_l = [] - if entities.Void in self.chans: + if entities.Void in self.chans or self.env.void_barrier: entities_l.append(entities.Void(self.env)) for entity in self.env.entities: sq_dist = ((self.env.agent.pos[0]-entity.pos[0])*self.env.width) ** 2 \ @@ -130,7 +133,10 @@ class RayObservable(Observable): def get_observation(self): entities = self._get_possible_entities() - self.rays = np.zeros((self.num_rays, self.num_chans)) + self.rays = np.zeros((self.num_rays+self.include_rand, self.num_chans)) + if self.include_rand: + for c in range(self.num_chans): + self.rays[-1, c] = self.env.random() for r, (hx, hy) in enumerate(self._get_ray_heads()): occ_dist = self.num_steps for c, entity_type in enumerate(self.chans): @@ -162,7 +168,7 @@ class RayObservable(Observable): pygame.draw.circle(self.env.screen, col, (rx, ry), 3, width=0) -def StateObservable(Observable): +class StateObservable(Observable): def __init__(self, coordsAgent=False, speedAgent=False, coordsRelativeToAgent=True, coordsRewards=True, rewardsWhitelist=None, coordsEnemys=True, enemysWhitelist=None, enemysNoBarriers=True, rewardsTimeouts=True): super(StateObservable, self).__init__() self._entities = None @@ -181,6 +187,7 @@ def StateObservable(Observable): def entities(self): if self._entities: return self._entities + self.env.setup() self.rewardsWhitelist = self.rewardsWhitelist or self.env.entities self.enemysWhitelist = self.enemysWhitelist or self.env.entities self._entities = [] @@ -190,12 +197,12 @@ def StateObservable(Observable): for entity in self.rewardsWhitelist: if isinstance(entity, entities.Reward): self._entities.append(entity) - if self.coordEnemys: + if self.coordsEnemys: for entity in self.enemysWhitelist: if isinstance(entity, entities.Enemy): if not self.enemysNoBarriers or not isinstance(entity, entities.Barrier): self._entities.append(entity) - if self.rewardsTimeout: + if self.rewardsTimeouts: for entity in self.enemysWhitelist: if isinstance(entity, entities.TimeoutReward): self._timeoutEntities.append(entity) @@ -203,7 +210,7 @@ def StateObservable(Observable): def get_observation_space(self): return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1, - shape=(len(self.entities)*2+len(self._timeoutEntities) + self.speedAgent), dtype=np.float32) + shape=(len(self.entities)*2+len(self._timeoutEntities) + self.speedAgent,), dtype=np.float32) def get_observation(self): obs = []