Also imlemented RayCasting for Rectangles
This commit is contained in:
		
							parent
							
								
									c34d266ea5
								
							
						
					
					
						commit
						5cedffa473
					
				@ -417,7 +417,7 @@ class ColumbusTest3_1(ColumbusEnv):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ColumbusTestRect(ColumbusEnv):
 | 
			
		||||
    def __init__(self, observable=observables.Observable(), fps=30, aux_reward_max=1, **kw):
 | 
			
		||||
    def __init__(self, observable=observables.RayObservable(), fps=30, aux_reward_max=1, **kw):
 | 
			
		||||
        super().__init__(
 | 
			
		||||
            observable=observable, fps=fps, env_seed=3.3, aux_reward_max=aux_reward_max, controll_type='ACC', **kw)
 | 
			
		||||
        self.start_pos = [0.5, 0.5]
 | 
			
		||||
 | 
			
		||||
@ -115,12 +115,20 @@ class RayObservable(Observable):
 | 
			
		||||
                    if not self.env.torus_topology and (0 >= pos[0] or pos[0] >= self.env.width or 0 >= pos[1] or pos[1] >= self.env.width):
 | 
			
		||||
                        return True
 | 
			
		||||
                else:
 | 
			
		||||
                    if entity.shape != 'circle':
 | 
			
		||||
                        raise Exception('Can only raycast circular entities!')
 | 
			
		||||
                    sq_dist = (pos[0]-entity.pos[0]*self.env.width) ** 2 \
 | 
			
		||||
                        + (pos[1]-entity.pos[1]*self.env.height)**2
 | 
			
		||||
                    if sq_dist < entity.radius**2:
 | 
			
		||||
                        return True
 | 
			
		||||
                    if entity.shape == 'circle':
 | 
			
		||||
                        sq_dist = (pos[0]-entity.pos[0]*self.env.width) ** 2 \
 | 
			
		||||
                            + (pos[1]-entity.pos[1]*self.env.height)**2
 | 
			
		||||
                        if sq_dist < entity.radius**2:
 | 
			
		||||
                            return True
 | 
			
		||||
                    elif entity.shape == 'rect':
 | 
			
		||||
                        dot = entities.CircularEntity(self.env)
 | 
			
		||||
                        dot.radius = 1
 | 
			
		||||
                        dot.pos = pos[0]/self.env.width, pos[1]/self.env.height
 | 
			
		||||
                        if sum(dot._get_crash_force_dir(entity)) != 0:
 | 
			
		||||
                            return True
 | 
			
		||||
                    else:
 | 
			
		||||
                        raise Exception(
 | 
			
		||||
                            'Can only raycast circular and rectangular entities!')
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    def _get_possible_entities(self):
 | 
			
		||||
@ -128,9 +136,16 @@ class RayObservable(Observable):
 | 
			
		||||
        if entities.Void in self.chans or self.env.void_barrier:
 | 
			
		||||
            entities_l.append(entities.Void(self.env))
 | 
			
		||||
        for entity in self.env.entities:
 | 
			
		||||
            if entity.shape == 'rect':
 | 
			
		||||
                radius = (entity.width/2 + entity.height/2)*1.0
 | 
			
		||||
            elif entity.shape == 'circle':
 | 
			
		||||
                radius = entity.radius
 | 
			
		||||
            else:
 | 
			
		||||
                raise Exception(
 | 
			
		||||
                    'Can only raycast circular and rectangular entities!')
 | 
			
		||||
            sq_dist = ((self.env.agent.pos[0]-entity.pos[0])*self.env.width) ** 2 \
 | 
			
		||||
                + ((self.env.agent.pos[1]-entity.pos[1])*self.env.height) ** 2
 | 
			
		||||
            if sq_dist <= (entity.radius + self.env.agent.radius + self.ray_len)**2:
 | 
			
		||||
            if sq_dist <= (radius + self.env.agent.radius + self.ray_len)**2:
 | 
			
		||||
                entities_l.append(entity)  # cannot use yield here!
 | 
			
		||||
        return entities_l
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user