From 6898c2deb5da68d6306a2913ce0f16c9fb836d78 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sun, 19 Jun 2022 21:47:35 +0200 Subject: [PATCH] Made RayTracing more performant (exclude distant entities) --- columbus/observables.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/columbus/observables.py b/columbus/observables.py index 726cc41..49d0395 100644 --- a/columbus/observables.py +++ b/columbus/observables.py @@ -100,8 +100,8 @@ class RayObservable(Observable): rad = 2*math.pi/self.num_rays*i yield self.ray_len*math.sin(rad), self.ray_len*math.cos(rad) - def _check_collision(self, pos, entity_type): - for entity in self.env.entities: + def _check_collision(self, pos, entity_type, entities): + for entity in entities: if isinstance(entity, entity_type): if entity.shape != 'circle': raise Exception('Can only raycast circular entities!') @@ -111,7 +111,15 @@ class RayObservable(Observable): return True return False + def _get_possible_entities(self): + for entity in self.env.entities: + sq_dist = ((self.env.agent.pos[0]-entity.pos[0])*self.env.width) ** 2 \ + + ((self.env.agent.pos[1]-entity.pos[1])*self.env.height) ** 2 + if sq_dist <= (entity.radius + self.env.agent.radius)**2: + yield entity + def get_observation(self): + entities = self._get_possible_entities() self.rays = np.zeros((self.num_rays, self.num_chans)) for r, (hx, hy) in enumerate(self._get_ray_heads()): occ_dist = self.num_steps @@ -123,7 +131,7 @@ class RayObservable(Observable): rx, ry = sx + \ self.env.agent.pos[0]*self.env.width, sy + \ self.env.agent.pos[1]*self.env.height - if self._check_collision((rx, ry), entity_type): + if self._check_collision((rx, ry), entity_type, entities): self.rays[r, c] = self.num_steps-s if self.occlusion: occ_dist = s