Added entities.Void for RayTracing detection of Border

This commit is contained in:
Dominik Moritz Roth 2022-06-19 23:14:39 +02:00
parent 253ce6495e
commit 1c81c6bae9
3 changed files with 27 additions and 12 deletions

View File

@ -193,3 +193,9 @@ class TimeoutReward(OnceReward):
self.env.new_abs_reward += self.reward
self.set_avaible(False)
self.env.timers.append((self.timeout, self.set_avaible, True))
class Void():
def __init__(self, env):
self.col = (50, 50, 50)
pass

View File

@ -8,6 +8,7 @@ from observables import Observable, CnnObservable
def main():
env = ColumbusTest3_1()
env = ColumbusTestRay(hide_map=True)
env.start_pos = [0.6, 0.3]
playEnv(env)
env.close()

View File

@ -82,7 +82,7 @@ def _clip(num, lower, upper):
class RayObservable(Observable):
def __init__(self, num_rays=24, chans=[entities.Enemy, entities.Reward], ray_len=256):
def __init__(self, num_rays=24, chans=[entities.Enemy, entities.Reward, entities.Void], ray_len=256):
super(RayObservable, self).__init__()
self.num_rays = num_rays
self.chans = chans
@ -100,9 +100,15 @@ class RayObservable(Observable):
rad = 2*math.pi/self.num_rays*i
yield self.ray_len*math.sin(rad), self.ray_len*math.cos(rad)
def _check_collision(self, pos, entity_type, entities):
for entity in entities:
def _check_collision(self, pos, entity_type, entities_l):
for entity in entities_l:
if isinstance(entity, entity_type):
if isinstance(entity, entities.Void):
hit = 0 >= pos[0] or pos[0] >= self.env.width or 0 >= pos[1] or pos[0] >= self.env.height
if hit:
print(pos)
return hit
else:
if entity.shape != 'circle':
raise Exception('Can only raycast circular entities!')
sq_dist = (pos[0]-entity.pos[0]*self.env.width) ** 2 \
@ -112,13 +118,15 @@ class RayObservable(Observable):
return False
def _get_possible_entities(self):
entities = []
entities_l = []
if entities.Void in self.chans:
entities_l.append(entities.Void(self.env))
for entity in self.env.entities:
sq_dist = ((self.env.agent.pos[0]-entity.pos[0])*self.env.width) ** 2 \
+ ((self.env.agent.pos[1]-entity.pos[1])*self.env.height) ** 2
if sq_dist <= (entity.radius + self.env.agent.radius + self.ray_len)**2:
entities.append(entity) # cannot use yield here!
return entities
entities_l.append(entity) # cannot use yield here!
return entities_l
def get_observation(self):
entities = self._get_possible_entities()