From e3a1044cb3332bf3dfa45722320679755fa960af Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sun, 19 Jun 2022 20:33:45 +0200 Subject: [PATCH] Implemented RayObserver --- columbus/observables.py | 73 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/columbus/observables.py b/columbus/observables.py index fac929b..726cc41 100644 --- a/columbus/observables.py +++ b/columbus/observables.py @@ -1,6 +1,8 @@ from gym import spaces import numpy as np import pygame +import math +from columbus import entities class Observable(): @@ -16,6 +18,12 @@ class Observable(): return spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8) + def get_observation(self): + return False + + def draw(self): + pass + class CnnObservable(Observable): def __init__(self, in_width=256, in_height=256, out_width=32, out_height=32, draw_width=128, draw_height=128, smooth_scaling=True): @@ -36,6 +44,8 @@ class CnnObservable(Observable): shape=(self.out_width, self.out_height, 3), dtype=np.uint8) def get_observation(self): + if not self.env._rendered: + self.env.render(dont_show=True) self.env._ensure_surface() x, y = self.env.agent.pos[0]*self.env.width - self.in_width / \ 2, self.env.agent.pos[1]*self.env.height - self.in_height/2 @@ -69,3 +79,66 @@ class CnnObservable(Observable): def _clip(num, lower, upper): return min(max(num, lower), upper) + + +class RayObservable(Observable): + def __init__(self, num_rays=24, chans=[entities.Enemy, entities.Reward], ray_len=256): + super(RayObservable, self).__init__() + self.num_rays = num_rays + self.chans = chans + self.num_chans = len(chans) + self.ray_len = ray_len + self.num_steps = 32 # max = 255 + self.occlusion = True # previous channels block view onto later channels + + def get_observation_space(self): + return spaces.Box(low=0, high=self.num_steps, + shape=(self.num_rays, self.num_chans), dtype=np.uint8) + + def _get_ray_heads(self): + for i in range(self.num_rays): + rad = 2*math.pi/self.num_rays*i + yield self.ray_len*math.sin(rad), self.ray_len*math.cos(rad) + + def _check_collision(self, pos, entity_type): + for entity in self.env.entities: + if isinstance(entity, entity_type): + if entity.shape != 'circle': + raise Exception('Can only raycast circular entities!') + sq_dist = (pos[0]-entity.pos[0]*self.env.width) ** 2 \ + + (pos[1]-entity.pos[1]*self.env.height)**2 + if sq_dist < entity.radius**2: + return True + return False + + def get_observation(self): + self.rays = np.zeros((self.num_rays, self.num_chans)) + for r, (hx, hy) in enumerate(self._get_ray_heads()): + occ_dist = self.num_steps + for c, entity_type in enumerate(self.chans): + for s in range(self.num_steps): + if s > occ_dist: + break + sx, sy = s*hx/self.num_steps, s*hy/self.num_steps + rx, ry = sx + \ + self.env.agent.pos[0]*self.env.width, sy + \ + self.env.agent.pos[1]*self.env.height + if self._check_collision((rx, ry), entity_type): + self.rays[r, c] = self.num_steps-s + if self.occlusion: + occ_dist = s + break + return self.rays + + def draw(self): + for c, entity_type in enumerate(self.chans): + for r, (hx, hy) in enumerate(self._get_ray_heads()): + s = self.num_steps - self.rays[r, c] + sx, sy = s*hx/self.num_steps, s*hy/self.num_steps + rx, ry = sx + \ + self.env.agent.pos[0]*self.env.width, sy + \ + self.env.agent.pos[1]*self.env.height + # TODO: How stupid do I want to code? + col = entity_type(self.env).col + col = int(col[0]/2), int(col[1]/2), int(col[2]/2) + pygame.draw.circle(self.env.screen, col, (rx, ry), 3, width=0)