diff --git a/columbus/entities.py b/columbus/entities.py index 56fdc20..4ec910e 100644 --- a/columbus/entities.py +++ b/columbus/entities.py @@ -193,3 +193,9 @@ class TimeoutReward(OnceReward): self.env.new_abs_reward += self.reward self.set_avaible(False) self.env.timers.append((self.timeout, self.set_avaible, True)) + + +class Void(): + def __init__(self, env): + self.col = (50, 50, 50) + pass diff --git a/columbus/humanPlayer.py b/columbus/humanPlayer.py index 3dcceed..c2e6495 100644 --- a/columbus/humanPlayer.py +++ b/columbus/humanPlayer.py @@ -8,6 +8,7 @@ from observables import Observable, CnnObservable def main(): env = ColumbusTest3_1() + env = ColumbusTestRay(hide_map=True) env.start_pos = [0.6, 0.3] playEnv(env) env.close() diff --git a/columbus/observables.py b/columbus/observables.py index 66b7318..90b652d 100644 --- a/columbus/observables.py +++ b/columbus/observables.py @@ -82,7 +82,7 @@ def _clip(num, lower, upper): class RayObservable(Observable): - def __init__(self, num_rays=24, chans=[entities.Enemy, entities.Reward], ray_len=256): + def __init__(self, num_rays=24, chans=[entities.Enemy, entities.Reward, entities.Void], ray_len=256): super(RayObservable, self).__init__() self.num_rays = num_rays self.chans = chans @@ -100,25 +100,33 @@ class RayObservable(Observable): rad = 2*math.pi/self.num_rays*i yield self.ray_len*math.sin(rad), self.ray_len*math.cos(rad) - def _check_collision(self, pos, entity_type, entities): - for entity in entities: + def _check_collision(self, pos, entity_type, entities_l): + for entity in entities_l: if isinstance(entity, entity_type): - if entity.shape != 'circle': - raise Exception('Can only raycast circular entities!') - sq_dist = (pos[0]-entity.pos[0]*self.env.width) ** 2 \ - + (pos[1]-entity.pos[1]*self.env.height)**2 - if sq_dist < entity.radius**2: - return True + if isinstance(entity, entities.Void): + hit = 0 >= pos[0] or pos[0] >= self.env.width or 0 >= pos[1] or pos[0] >= self.env.height + if hit: + print(pos) + return hit + else: + if entity.shape != 'circle': + raise Exception('Can only raycast circular entities!') + sq_dist = (pos[0]-entity.pos[0]*self.env.width) ** 2 \ + + (pos[1]-entity.pos[1]*self.env.height)**2 + if sq_dist < entity.radius**2: + return True return False def _get_possible_entities(self): - entities = [] + entities_l = [] + if entities.Void in self.chans: + entities_l.append(entities.Void(self.env)) for entity in self.env.entities: sq_dist = ((self.env.agent.pos[0]-entity.pos[0])*self.env.width) ** 2 \ + ((self.env.agent.pos[1]-entity.pos[1])*self.env.height) ** 2 if sq_dist <= (entity.radius + self.env.agent.radius + self.ray_len)**2: - entities.append(entity) # cannot use yield here! - return entities + entities_l.append(entity) # cannot use yield here! + return entities_l def get_observation(self): entities = self._get_possible_entities()