Bug fixes and minor additions to observables
This commit is contained in:
parent
1eb86bef06
commit
4bfa15b362
@ -57,6 +57,10 @@ class CnnObservable(Observable):
|
|||||||
rect = pygame.Rect(cx, cy, cw, ch)
|
rect = pygame.Rect(cx, cy, cw, ch)
|
||||||
snap = self.env.surface.subsurface(rect)
|
snap = self.env.surface.subsurface(rect)
|
||||||
self.snap = pygame.Surface((self.in_width, self.in_height))
|
self.snap = pygame.Surface((self.in_width, self.in_height))
|
||||||
|
if self.env.void_barrier:
|
||||||
|
col = (223, 0, 0)
|
||||||
|
else:
|
||||||
|
col = (50, 50, 50)
|
||||||
pygame.draw.rect(self.snap, (50, 50, 50),
|
pygame.draw.rect(self.snap, (50, 50, 50),
|
||||||
pygame.Rect(0, 0, self.in_width, self.in_height))
|
pygame.Rect(0, 0, self.in_width, self.in_height))
|
||||||
self.snap.blit(snap, (cx - x, cy - y))
|
self.snap.blit(snap, (cx - x, cy - y))
|
||||||
@ -82,18 +86,19 @@ def _clip(num, lower, upper):
|
|||||||
|
|
||||||
|
|
||||||
class RayObservable(Observable):
|
class RayObservable(Observable):
|
||||||
def __init__(self, num_rays=24, chans=[entities.Enemy, entities.Reward, entities.Void], ray_len=256):
|
def __init__(self, num_rays=16, chans=[entities.Enemy, entities.Reward], ray_len=256, num_steps=64, include_rand=False):
|
||||||
super(RayObservable, self).__init__()
|
super(RayObservable, self).__init__()
|
||||||
self.num_rays = num_rays
|
self.num_rays = num_rays
|
||||||
self.chans = chans
|
self.chans = chans
|
||||||
self.num_chans = len(chans)
|
self.num_chans = len(chans)
|
||||||
self.ray_len = ray_len
|
self.ray_len = ray_len
|
||||||
self.num_steps = 32 # max = 255
|
self.num_steps = num_steps # max = 255
|
||||||
self.occlusion = True # previous channels block view onto later channels
|
self.occlusion = True # previous channels block view onto later channels
|
||||||
|
self.include_rand = include_rand
|
||||||
|
|
||||||
def get_observation_space(self):
|
def get_observation_space(self):
|
||||||
return spaces.Box(low=0, high=self.num_steps,
|
return spaces.Box(low=0, high=self.num_steps,
|
||||||
shape=(self.num_rays, self.num_chans), dtype=np.uint8)
|
shape=(self.num_rays+self.include_rand, self.num_chans), dtype=np.uint8)
|
||||||
|
|
||||||
def _get_ray_heads(self):
|
def _get_ray_heads(self):
|
||||||
for i in range(self.num_rays):
|
for i in range(self.num_rays):
|
||||||
@ -102,12 +107,10 @@ class RayObservable(Observable):
|
|||||||
|
|
||||||
def _check_collision(self, pos, entity_type, entities_l):
|
def _check_collision(self, pos, entity_type, entities_l):
|
||||||
for entity in entities_l:
|
for entity in entities_l:
|
||||||
if isinstance(entity, entity_type):
|
if isinstance(entity, entity_type) or (self.env.void_barrier and isinstance(entity, entities.Void) and entity_type == entities.Enemy):
|
||||||
if isinstance(entity, entities.Void):
|
if isinstance(entity, entities.Void):
|
||||||
hit = 0 >= pos[0] or pos[0] >= self.env.width or 0 >= pos[1] or pos[0] >= self.env.height
|
if 0 >= pos[0] or pos[0] >= self.env.width or 0 >= pos[1] or pos[0] >= self.env.height:
|
||||||
if hit:
|
return True
|
||||||
print(pos)
|
|
||||||
return hit
|
|
||||||
else:
|
else:
|
||||||
if entity.shape != 'circle':
|
if entity.shape != 'circle':
|
||||||
raise Exception('Can only raycast circular entities!')
|
raise Exception('Can only raycast circular entities!')
|
||||||
@ -119,7 +122,7 @@ class RayObservable(Observable):
|
|||||||
|
|
||||||
def _get_possible_entities(self):
|
def _get_possible_entities(self):
|
||||||
entities_l = []
|
entities_l = []
|
||||||
if entities.Void in self.chans:
|
if entities.Void in self.chans or self.env.void_barrier:
|
||||||
entities_l.append(entities.Void(self.env))
|
entities_l.append(entities.Void(self.env))
|
||||||
for entity in self.env.entities:
|
for entity in self.env.entities:
|
||||||
sq_dist = ((self.env.agent.pos[0]-entity.pos[0])*self.env.width) ** 2 \
|
sq_dist = ((self.env.agent.pos[0]-entity.pos[0])*self.env.width) ** 2 \
|
||||||
@ -130,7 +133,10 @@ class RayObservable(Observable):
|
|||||||
|
|
||||||
def get_observation(self):
|
def get_observation(self):
|
||||||
entities = self._get_possible_entities()
|
entities = self._get_possible_entities()
|
||||||
self.rays = np.zeros((self.num_rays, self.num_chans))
|
self.rays = np.zeros((self.num_rays+self.include_rand, self.num_chans))
|
||||||
|
if self.include_rand:
|
||||||
|
for c in range(self.num_chans):
|
||||||
|
self.rays[-1, c] = self.env.random()
|
||||||
for r, (hx, hy) in enumerate(self._get_ray_heads()):
|
for r, (hx, hy) in enumerate(self._get_ray_heads()):
|
||||||
occ_dist = self.num_steps
|
occ_dist = self.num_steps
|
||||||
for c, entity_type in enumerate(self.chans):
|
for c, entity_type in enumerate(self.chans):
|
||||||
@ -162,7 +168,7 @@ class RayObservable(Observable):
|
|||||||
pygame.draw.circle(self.env.screen, col, (rx, ry), 3, width=0)
|
pygame.draw.circle(self.env.screen, col, (rx, ry), 3, width=0)
|
||||||
|
|
||||||
|
|
||||||
def StateObservable(Observable):
|
class StateObservable(Observable):
|
||||||
def __init__(self, coordsAgent=False, speedAgent=False, coordsRelativeToAgent=True, coordsRewards=True, rewardsWhitelist=None, coordsEnemys=True, enemysWhitelist=None, enemysNoBarriers=True, rewardsTimeouts=True):
|
def __init__(self, coordsAgent=False, speedAgent=False, coordsRelativeToAgent=True, coordsRewards=True, rewardsWhitelist=None, coordsEnemys=True, enemysWhitelist=None, enemysNoBarriers=True, rewardsTimeouts=True):
|
||||||
super(StateObservable, self).__init__()
|
super(StateObservable, self).__init__()
|
||||||
self._entities = None
|
self._entities = None
|
||||||
@ -181,6 +187,7 @@ def StateObservable(Observable):
|
|||||||
def entities(self):
|
def entities(self):
|
||||||
if self._entities:
|
if self._entities:
|
||||||
return self._entities
|
return self._entities
|
||||||
|
self.env.setup()
|
||||||
self.rewardsWhitelist = self.rewardsWhitelist or self.env.entities
|
self.rewardsWhitelist = self.rewardsWhitelist or self.env.entities
|
||||||
self.enemysWhitelist = self.enemysWhitelist or self.env.entities
|
self.enemysWhitelist = self.enemysWhitelist or self.env.entities
|
||||||
self._entities = []
|
self._entities = []
|
||||||
@ -190,12 +197,12 @@ def StateObservable(Observable):
|
|||||||
for entity in self.rewardsWhitelist:
|
for entity in self.rewardsWhitelist:
|
||||||
if isinstance(entity, entities.Reward):
|
if isinstance(entity, entities.Reward):
|
||||||
self._entities.append(entity)
|
self._entities.append(entity)
|
||||||
if self.coordEnemys:
|
if self.coordsEnemys:
|
||||||
for entity in self.enemysWhitelist:
|
for entity in self.enemysWhitelist:
|
||||||
if isinstance(entity, entities.Enemy):
|
if isinstance(entity, entities.Enemy):
|
||||||
if not self.enemysNoBarriers or not isinstance(entity, entities.Barrier):
|
if not self.enemysNoBarriers or not isinstance(entity, entities.Barrier):
|
||||||
self._entities.append(entity)
|
self._entities.append(entity)
|
||||||
if self.rewardsTimeout:
|
if self.rewardsTimeouts:
|
||||||
for entity in self.enemysWhitelist:
|
for entity in self.enemysWhitelist:
|
||||||
if isinstance(entity, entities.TimeoutReward):
|
if isinstance(entity, entities.TimeoutReward):
|
||||||
self._timeoutEntities.append(entity)
|
self._timeoutEntities.append(entity)
|
||||||
@ -203,7 +210,7 @@ def StateObservable(Observable):
|
|||||||
|
|
||||||
def get_observation_space(self):
|
def get_observation_space(self):
|
||||||
return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1,
|
return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1,
|
||||||
shape=(len(self.entities)*2+len(self._timeoutEntities) + self.speedAgent), dtype=np.float32)
|
shape=(len(self.entities)*2+len(self._timeoutEntities) + self.speedAgent,), dtype=np.float32)
|
||||||
|
|
||||||
def get_observation(self):
|
def get_observation(self):
|
||||||
obs = []
|
obs = []
|
||||||
|
Loading…
Reference in New Issue
Block a user