Added entities.Void for RayTracing detection of Border

This commit is contained in:
Dominik Moritz Roth 2022-06-19 23:14:39 +02:00
parent 253ce6495e
commit 1c81c6bae9
3 changed files with 27 additions and 12 deletions

View File

@ -193,3 +193,9 @@ class TimeoutReward(OnceReward):
self.env.new_abs_reward += self.reward self.env.new_abs_reward += self.reward
self.set_avaible(False) self.set_avaible(False)
self.env.timers.append((self.timeout, self.set_avaible, True)) self.env.timers.append((self.timeout, self.set_avaible, True))
class Void():
def __init__(self, env):
self.col = (50, 50, 50)
pass

View File

@ -8,6 +8,7 @@ from observables import Observable, CnnObservable
def main(): def main():
env = ColumbusTest3_1() env = ColumbusTest3_1()
env = ColumbusTestRay(hide_map=True)
env.start_pos = [0.6, 0.3] env.start_pos = [0.6, 0.3]
playEnv(env) playEnv(env)
env.close() env.close()

View File

@ -82,7 +82,7 @@ def _clip(num, lower, upper):
class RayObservable(Observable): class RayObservable(Observable):
def __init__(self, num_rays=24, chans=[entities.Enemy, entities.Reward], ray_len=256): def __init__(self, num_rays=24, chans=[entities.Enemy, entities.Reward, entities.Void], ray_len=256):
super(RayObservable, self).__init__() super(RayObservable, self).__init__()
self.num_rays = num_rays self.num_rays = num_rays
self.chans = chans self.chans = chans
@ -100,25 +100,33 @@ class RayObservable(Observable):
rad = 2*math.pi/self.num_rays*i rad = 2*math.pi/self.num_rays*i
yield self.ray_len*math.sin(rad), self.ray_len*math.cos(rad) yield self.ray_len*math.sin(rad), self.ray_len*math.cos(rad)
def _check_collision(self, pos, entity_type, entities): def _check_collision(self, pos, entity_type, entities_l):
for entity in entities: for entity in entities_l:
if isinstance(entity, entity_type): if isinstance(entity, entity_type):
if entity.shape != 'circle': if isinstance(entity, entities.Void):
raise Exception('Can only raycast circular entities!') hit = 0 >= pos[0] or pos[0] >= self.env.width or 0 >= pos[1] or pos[0] >= self.env.height
sq_dist = (pos[0]-entity.pos[0]*self.env.width) ** 2 \ if hit:
+ (pos[1]-entity.pos[1]*self.env.height)**2 print(pos)
if sq_dist < entity.radius**2: return hit
return True else:
if entity.shape != 'circle':
raise Exception('Can only raycast circular entities!')
sq_dist = (pos[0]-entity.pos[0]*self.env.width) ** 2 \
+ (pos[1]-entity.pos[1]*self.env.height)**2
if sq_dist < entity.radius**2:
return True
return False return False
def _get_possible_entities(self): def _get_possible_entities(self):
entities = [] entities_l = []
if entities.Void in self.chans:
entities_l.append(entities.Void(self.env))
for entity in self.env.entities: for entity in self.env.entities:
sq_dist = ((self.env.agent.pos[0]-entity.pos[0])*self.env.width) ** 2 \ sq_dist = ((self.env.agent.pos[0]-entity.pos[0])*self.env.width) ** 2 \
+ ((self.env.agent.pos[1]-entity.pos[1])*self.env.height) ** 2 + ((self.env.agent.pos[1]-entity.pos[1])*self.env.height) ** 2
if sq_dist <= (entity.radius + self.env.agent.radius + self.ray_len)**2: if sq_dist <= (entity.radius + self.env.agent.radius + self.ray_len)**2:
entities.append(entity) # cannot use yield here! entities_l.append(entity) # cannot use yield here!
return entities return entities_l
def get_observation(self): def get_observation(self):
entities = self._get_possible_entities() entities = self._get_possible_entities()