Implemented RayObserver

This commit is contained in:
Dominik Moritz Roth 2022-06-19 20:33:45 +02:00
parent d731695f2a
commit e3a1044cb3

View File

@ -1,6 +1,8 @@
from gym import spaces
import numpy as np
import pygame
import math
from columbus import entities
class Observable():
@ -16,6 +18,12 @@ class Observable():
return spaces.Box(low=0, high=255,
shape=(1,), dtype=np.uint8)
def get_observation(self):
return False
def draw(self):
pass
class CnnObservable(Observable):
def __init__(self, in_width=256, in_height=256, out_width=32, out_height=32, draw_width=128, draw_height=128, smooth_scaling=True):
@ -36,6 +44,8 @@ class CnnObservable(Observable):
shape=(self.out_width, self.out_height, 3), dtype=np.uint8)
def get_observation(self):
if not self.env._rendered:
self.env.render(dont_show=True)
self.env._ensure_surface()
x, y = self.env.agent.pos[0]*self.env.width - self.in_width / \
2, self.env.agent.pos[1]*self.env.height - self.in_height/2
@ -69,3 +79,66 @@ class CnnObservable(Observable):
def _clip(num, lower, upper):
return min(max(num, lower), upper)
class RayObservable(Observable):
def __init__(self, num_rays=24, chans=[entities.Enemy, entities.Reward], ray_len=256):
super(RayObservable, self).__init__()
self.num_rays = num_rays
self.chans = chans
self.num_chans = len(chans)
self.ray_len = ray_len
self.num_steps = 32 # max = 255
self.occlusion = True # previous channels block view onto later channels
def get_observation_space(self):
return spaces.Box(low=0, high=self.num_steps,
shape=(self.num_rays, self.num_chans), dtype=np.uint8)
def _get_ray_heads(self):
for i in range(self.num_rays):
rad = 2*math.pi/self.num_rays*i
yield self.ray_len*math.sin(rad), self.ray_len*math.cos(rad)
def _check_collision(self, pos, entity_type):
for entity in self.env.entities:
if isinstance(entity, entity_type):
if entity.shape != 'circle':
raise Exception('Can only raycast circular entities!')
sq_dist = (pos[0]-entity.pos[0]*self.env.width) ** 2 \
+ (pos[1]-entity.pos[1]*self.env.height)**2
if sq_dist < entity.radius**2:
return True
return False
def get_observation(self):
self.rays = np.zeros((self.num_rays, self.num_chans))
for r, (hx, hy) in enumerate(self._get_ray_heads()):
occ_dist = self.num_steps
for c, entity_type in enumerate(self.chans):
for s in range(self.num_steps):
if s > occ_dist:
break
sx, sy = s*hx/self.num_steps, s*hy/self.num_steps
rx, ry = sx + \
self.env.agent.pos[0]*self.env.width, sy + \
self.env.agent.pos[1]*self.env.height
if self._check_collision((rx, ry), entity_type):
self.rays[r, c] = self.num_steps-s
if self.occlusion:
occ_dist = s
break
return self.rays
def draw(self):
for c, entity_type in enumerate(self.chans):
for r, (hx, hy) in enumerate(self._get_ray_heads()):
s = self.num_steps - self.rays[r, c]
sx, sy = s*hx/self.num_steps, s*hy/self.num_steps
rx, ry = sx + \
self.env.agent.pos[0]*self.env.width, sy + \
self.env.agent.pos[1]*self.env.height
# TODO: How stupid do I want to code?
col = entity_type(self.env).col
col = int(col[0]/2), int(col[1]/2), int(col[2]/2)
pygame.draw.circle(self.env.screen, col, (rx, ry), 3, width=0)