From 3dc5bed9f4f6c93b90ea83b10f811e2ae8950bf0 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sun, 19 Jun 2022 22:58:30 +0200 Subject: [PATCH] Bug fixes for RayTracing --- columbus/observables.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/columbus/observables.py b/columbus/observables.py index a61d66a..66b7318 100644 --- a/columbus/observables.py +++ b/columbus/observables.py @@ -112,11 +112,13 @@ class RayObservable(Observable): return False def _get_possible_entities(self): + entities = [] for entity in self.env.entities: sq_dist = ((self.env.agent.pos[0]-entity.pos[0])*self.env.width) ** 2 \ + ((self.env.agent.pos[1]-entity.pos[1])*self.env.height) ** 2 - if sq_dist <= (entity.radius + self.env.agent.radius)**2: - yield entity + if sq_dist <= (entity.radius + self.env.agent.radius + self.ray_len)**2: + entities.append(entity) # cannot use yield here! + return entities def get_observation(self): entities = self._get_possible_entities() @@ -127,7 +129,7 @@ class RayObservable(Observable): for s in range(self.num_steps): if s > occ_dist: break - sx, sy = s*hx/self.num_steps, s*hy/self.num_steps + sx, sy = (s+1)*hx/self.num_steps, (s+1)*hy/self.num_steps rx, ry = sx + \ self.env.agent.pos[0]*self.env.width, sy + \ self.env.agent.pos[1]*self.env.height @@ -142,7 +144,7 @@ class RayObservable(Observable): for c, entity_type in enumerate(self.chans): for r, (hx, hy) in enumerate(self._get_ray_heads()): s = self.num_steps - self.rays[r, c] - sx, sy = s*hx/self.num_steps, s*hy/self.num_steps + sx, sy = (s+1)*hx/self.num_steps, (s+1)*hy/self.num_steps rx, ry = sx + \ self.env.agent.pos[0]*self.env.width, sy + \ self.env.agent.pos[1]*self.env.height