Made RayTracing more performant (exclude distant entities)
This commit is contained in:
parent
e3a1044cb3
commit
6898c2deb5
@ -100,8 +100,8 @@ class RayObservable(Observable):
|
|||||||
rad = 2*math.pi/self.num_rays*i
|
rad = 2*math.pi/self.num_rays*i
|
||||||
yield self.ray_len*math.sin(rad), self.ray_len*math.cos(rad)
|
yield self.ray_len*math.sin(rad), self.ray_len*math.cos(rad)
|
||||||
|
|
||||||
def _check_collision(self, pos, entity_type):
|
def _check_collision(self, pos, entity_type, entities):
|
||||||
for entity in self.env.entities:
|
for entity in entities:
|
||||||
if isinstance(entity, entity_type):
|
if isinstance(entity, entity_type):
|
||||||
if entity.shape != 'circle':
|
if entity.shape != 'circle':
|
||||||
raise Exception('Can only raycast circular entities!')
|
raise Exception('Can only raycast circular entities!')
|
||||||
@ -111,7 +111,15 @@ class RayObservable(Observable):
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def _get_possible_entities(self):
|
||||||
|
for entity in self.env.entities:
|
||||||
|
sq_dist = ((self.env.agent.pos[0]-entity.pos[0])*self.env.width) ** 2 \
|
||||||
|
+ ((self.env.agent.pos[1]-entity.pos[1])*self.env.height) ** 2
|
||||||
|
if sq_dist <= (entity.radius + self.env.agent.radius)**2:
|
||||||
|
yield entity
|
||||||
|
|
||||||
def get_observation(self):
|
def get_observation(self):
|
||||||
|
entities = self._get_possible_entities()
|
||||||
self.rays = np.zeros((self.num_rays, self.num_chans))
|
self.rays = np.zeros((self.num_rays, self.num_chans))
|
||||||
for r, (hx, hy) in enumerate(self._get_ray_heads()):
|
for r, (hx, hy) in enumerate(self._get_ray_heads()):
|
||||||
occ_dist = self.num_steps
|
occ_dist = self.num_steps
|
||||||
@ -123,7 +131,7 @@ class RayObservable(Observable):
|
|||||||
rx, ry = sx + \
|
rx, ry = sx + \
|
||||||
self.env.agent.pos[0]*self.env.width, sy + \
|
self.env.agent.pos[0]*self.env.width, sy + \
|
||||||
self.env.agent.pos[1]*self.env.height
|
self.env.agent.pos[1]*self.env.height
|
||||||
if self._check_collision((rx, ry), entity_type):
|
if self._check_collision((rx, ry), entity_type, entities):
|
||||||
self.rays[r, c] = self.num_steps-s
|
self.rays[r, c] = self.num_steps-s
|
||||||
if self.occlusion:
|
if self.occlusion:
|
||||||
occ_dist = s
|
occ_dist = s
|
||||||
|
Loading…
Reference in New Issue
Block a user