Made RayTracing more performant (exclude distant entities)
This commit is contained in:
parent
e3a1044cb3
commit
6898c2deb5
@ -100,8 +100,8 @@ class RayObservable(Observable):
|
||||
rad = 2*math.pi/self.num_rays*i
|
||||
yield self.ray_len*math.sin(rad), self.ray_len*math.cos(rad)
|
||||
|
||||
def _check_collision(self, pos, entity_type):
|
||||
for entity in self.env.entities:
|
||||
def _check_collision(self, pos, entity_type, entities):
|
||||
for entity in entities:
|
||||
if isinstance(entity, entity_type):
|
||||
if entity.shape != 'circle':
|
||||
raise Exception('Can only raycast circular entities!')
|
||||
@ -111,7 +111,15 @@ class RayObservable(Observable):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _get_possible_entities(self):
|
||||
for entity in self.env.entities:
|
||||
sq_dist = ((self.env.agent.pos[0]-entity.pos[0])*self.env.width) ** 2 \
|
||||
+ ((self.env.agent.pos[1]-entity.pos[1])*self.env.height) ** 2
|
||||
if sq_dist <= (entity.radius + self.env.agent.radius)**2:
|
||||
yield entity
|
||||
|
||||
def get_observation(self):
|
||||
entities = self._get_possible_entities()
|
||||
self.rays = np.zeros((self.num_rays, self.num_chans))
|
||||
for r, (hx, hy) in enumerate(self._get_ray_heads()):
|
||||
occ_dist = self.num_steps
|
||||
@ -123,7 +131,7 @@ class RayObservable(Observable):
|
||||
rx, ry = sx + \
|
||||
self.env.agent.pos[0]*self.env.width, sy + \
|
||||
self.env.agent.pos[1]*self.env.height
|
||||
if self._check_collision((rx, ry), entity_type):
|
||||
if self._check_collision((rx, ry), entity_type, entities):
|
||||
self.rays[r, c] = self.num_steps-s
|
||||
if self.occlusion:
|
||||
occ_dist = s
|
||||
|
Loading…
Reference in New Issue
Block a user