Added entities.Void for RayTracing detection of Border
This commit is contained in:
parent
253ce6495e
commit
1c81c6bae9
@ -193,3 +193,9 @@ class TimeoutReward(OnceReward):
|
|||||||
self.env.new_abs_reward += self.reward
|
self.env.new_abs_reward += self.reward
|
||||||
self.set_avaible(False)
|
self.set_avaible(False)
|
||||||
self.env.timers.append((self.timeout, self.set_avaible, True))
|
self.env.timers.append((self.timeout, self.set_avaible, True))
|
||||||
|
|
||||||
|
|
||||||
|
class Void():
|
||||||
|
def __init__(self, env):
|
||||||
|
self.col = (50, 50, 50)
|
||||||
|
pass
|
||||||
|
@ -8,6 +8,7 @@ from observables import Observable, CnnObservable
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
env = ColumbusTest3_1()
|
env = ColumbusTest3_1()
|
||||||
|
env = ColumbusTestRay(hide_map=True)
|
||||||
env.start_pos = [0.6, 0.3]
|
env.start_pos = [0.6, 0.3]
|
||||||
playEnv(env)
|
playEnv(env)
|
||||||
env.close()
|
env.close()
|
||||||
|
@ -82,7 +82,7 @@ def _clip(num, lower, upper):
|
|||||||
|
|
||||||
|
|
||||||
class RayObservable(Observable):
|
class RayObservable(Observable):
|
||||||
def __init__(self, num_rays=24, chans=[entities.Enemy, entities.Reward], ray_len=256):
|
def __init__(self, num_rays=24, chans=[entities.Enemy, entities.Reward, entities.Void], ray_len=256):
|
||||||
super(RayObservable, self).__init__()
|
super(RayObservable, self).__init__()
|
||||||
self.num_rays = num_rays
|
self.num_rays = num_rays
|
||||||
self.chans = chans
|
self.chans = chans
|
||||||
@ -100,9 +100,15 @@ class RayObservable(Observable):
|
|||||||
rad = 2*math.pi/self.num_rays*i
|
rad = 2*math.pi/self.num_rays*i
|
||||||
yield self.ray_len*math.sin(rad), self.ray_len*math.cos(rad)
|
yield self.ray_len*math.sin(rad), self.ray_len*math.cos(rad)
|
||||||
|
|
||||||
def _check_collision(self, pos, entity_type, entities):
|
def _check_collision(self, pos, entity_type, entities_l):
|
||||||
for entity in entities:
|
for entity in entities_l:
|
||||||
if isinstance(entity, entity_type):
|
if isinstance(entity, entity_type):
|
||||||
|
if isinstance(entity, entities.Void):
|
||||||
|
hit = 0 >= pos[0] or pos[0] >= self.env.width or 0 >= pos[1] or pos[0] >= self.env.height
|
||||||
|
if hit:
|
||||||
|
print(pos)
|
||||||
|
return hit
|
||||||
|
else:
|
||||||
if entity.shape != 'circle':
|
if entity.shape != 'circle':
|
||||||
raise Exception('Can only raycast circular entities!')
|
raise Exception('Can only raycast circular entities!')
|
||||||
sq_dist = (pos[0]-entity.pos[0]*self.env.width) ** 2 \
|
sq_dist = (pos[0]-entity.pos[0]*self.env.width) ** 2 \
|
||||||
@ -112,13 +118,15 @@ class RayObservable(Observable):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def _get_possible_entities(self):
|
def _get_possible_entities(self):
|
||||||
entities = []
|
entities_l = []
|
||||||
|
if entities.Void in self.chans:
|
||||||
|
entities_l.append(entities.Void(self.env))
|
||||||
for entity in self.env.entities:
|
for entity in self.env.entities:
|
||||||
sq_dist = ((self.env.agent.pos[0]-entity.pos[0])*self.env.width) ** 2 \
|
sq_dist = ((self.env.agent.pos[0]-entity.pos[0])*self.env.width) ** 2 \
|
||||||
+ ((self.env.agent.pos[1]-entity.pos[1])*self.env.height) ** 2
|
+ ((self.env.agent.pos[1]-entity.pos[1])*self.env.height) ** 2
|
||||||
if sq_dist <= (entity.radius + self.env.agent.radius + self.ray_len)**2:
|
if sq_dist <= (entity.radius + self.env.agent.radius + self.ray_len)**2:
|
||||||
entities.append(entity) # cannot use yield here!
|
entities_l.append(entity) # cannot use yield here!
|
||||||
return entities
|
return entities_l
|
||||||
|
|
||||||
def get_observation(self):
|
def get_observation(self):
|
||||||
entities = self._get_possible_entities()
|
entities = self._get_possible_entities()
|
||||||
|
Loading…
Reference in New Issue
Block a user