Also imlemented RayCasting for Rectangles

This commit is contained in:
Dominik Moritz Roth 2022-09-13 22:25:29 +02:00
parent c34d266ea5
commit 5cedffa473
2 changed files with 23 additions and 8 deletions

View File

@ -417,7 +417,7 @@ class ColumbusTest3_1(ColumbusEnv):
class ColumbusTestRect(ColumbusEnv): class ColumbusTestRect(ColumbusEnv):
def __init__(self, observable=observables.Observable(), fps=30, aux_reward_max=1, **kw): def __init__(self, observable=observables.RayObservable(), fps=30, aux_reward_max=1, **kw):
super().__init__( super().__init__(
observable=observable, fps=fps, env_seed=3.3, aux_reward_max=aux_reward_max, controll_type='ACC', **kw) observable=observable, fps=fps, env_seed=3.3, aux_reward_max=aux_reward_max, controll_type='ACC', **kw)
self.start_pos = [0.5, 0.5] self.start_pos = [0.5, 0.5]

View File

@ -115,12 +115,20 @@ class RayObservable(Observable):
if not self.env.torus_topology and (0 >= pos[0] or pos[0] >= self.env.width or 0 >= pos[1] or pos[1] >= self.env.width): if not self.env.torus_topology and (0 >= pos[0] or pos[0] >= self.env.width or 0 >= pos[1] or pos[1] >= self.env.width):
return True return True
else: else:
if entity.shape != 'circle': if entity.shape == 'circle':
raise Exception('Can only raycast circular entities!') sq_dist = (pos[0]-entity.pos[0]*self.env.width) ** 2 \
sq_dist = (pos[0]-entity.pos[0]*self.env.width) ** 2 \ + (pos[1]-entity.pos[1]*self.env.height)**2
+ (pos[1]-entity.pos[1]*self.env.height)**2 if sq_dist < entity.radius**2:
if sq_dist < entity.radius**2: return True
return True elif entity.shape == 'rect':
dot = entities.CircularEntity(self.env)
dot.radius = 1
dot.pos = pos[0]/self.env.width, pos[1]/self.env.height
if sum(dot._get_crash_force_dir(entity)) != 0:
return True
else:
raise Exception(
'Can only raycast circular and rectangular entities!')
return False return False
def _get_possible_entities(self): def _get_possible_entities(self):
@ -128,9 +136,16 @@ class RayObservable(Observable):
if entities.Void in self.chans or self.env.void_barrier: if entities.Void in self.chans or self.env.void_barrier:
entities_l.append(entities.Void(self.env)) entities_l.append(entities.Void(self.env))
for entity in self.env.entities: for entity in self.env.entities:
if entity.shape == 'rect':
radius = (entity.width/2 + entity.height/2)*1.0
elif entity.shape == 'circle':
radius = entity.radius
else:
raise Exception(
'Can only raycast circular and rectangular entities!')
sq_dist = ((self.env.agent.pos[0]-entity.pos[0])*self.env.width) ** 2 \ sq_dist = ((self.env.agent.pos[0]-entity.pos[0])*self.env.width) ** 2 \
+ ((self.env.agent.pos[1]-entity.pos[1])*self.env.height) ** 2 + ((self.env.agent.pos[1]-entity.pos[1])*self.env.height) ** 2
if sq_dist <= (entity.radius + self.env.agent.radius + self.ray_len)**2: if sq_dist <= (radius + self.env.agent.radius + self.ray_len)**2:
entities_l.append(entity) # cannot use yield here! entities_l.append(entity) # cannot use yield here!
return entities_l return entities_l