diff --git a/columbus/env.py b/columbus/env.py index c70bbea..6f82893 100644 --- a/columbus/env.py +++ b/columbus/env.py @@ -417,7 +417,7 @@ class ColumbusTest3_1(ColumbusEnv): class ColumbusTestRect(ColumbusEnv): - def __init__(self, observable=observables.Observable(), fps=30, aux_reward_max=1, **kw): + def __init__(self, observable=observables.RayObservable(), fps=30, aux_reward_max=1, **kw): super().__init__( observable=observable, fps=fps, env_seed=3.3, aux_reward_max=aux_reward_max, controll_type='ACC', **kw) self.start_pos = [0.5, 0.5] diff --git a/columbus/observables.py b/columbus/observables.py index 8224487..31e8174 100644 --- a/columbus/observables.py +++ b/columbus/observables.py @@ -115,12 +115,20 @@ class RayObservable(Observable): if not self.env.torus_topology and (0 >= pos[0] or pos[0] >= self.env.width or 0 >= pos[1] or pos[1] >= self.env.width): return True else: - if entity.shape != 'circle': - raise Exception('Can only raycast circular entities!') - sq_dist = (pos[0]-entity.pos[0]*self.env.width) ** 2 \ - + (pos[1]-entity.pos[1]*self.env.height)**2 - if sq_dist < entity.radius**2: - return True + if entity.shape == 'circle': + sq_dist = (pos[0]-entity.pos[0]*self.env.width) ** 2 \ + + (pos[1]-entity.pos[1]*self.env.height)**2 + if sq_dist < entity.radius**2: + return True + elif entity.shape == 'rect': + dot = entities.CircularEntity(self.env) + dot.radius = 1 + dot.pos = pos[0]/self.env.width, pos[1]/self.env.height + if sum(dot._get_crash_force_dir(entity)) != 0: + return True + else: + raise Exception( + 'Can only raycast circular and rectangular entities!') return False def _get_possible_entities(self): @@ -128,9 +136,16 @@ class RayObservable(Observable): if entities.Void in self.chans or self.env.void_barrier: entities_l.append(entities.Void(self.env)) for entity in self.env.entities: + if entity.shape == 'rect': + radius = (entity.width/2 + entity.height/2)*1.0 + elif entity.shape == 'circle': + radius = entity.radius + else: + raise Exception( + 'Can only raycast circular and rectangular entities!') sq_dist = ((self.env.agent.pos[0]-entity.pos[0])*self.env.width) ** 2 \ + ((self.env.agent.pos[1]-entity.pos[1])*self.env.height) ** 2 - if sq_dist <= (entity.radius + self.env.agent.radius + self.ray_len)**2: + if sq_dist <= (radius + self.env.agent.radius + self.ray_len)**2: entities_l.append(entity) # cannot use yield here! return entities_l