2022-06-19 15:01:30 +02:00
|
|
|
from gym import spaces
|
|
|
|
import numpy as np
|
|
|
|
import pygame
|
2022-06-19 20:33:45 +02:00
|
|
|
import math
|
|
|
|
from columbus import entities
|
2022-08-17 19:31:15 +02:00
|
|
|
import torch as th
|
2022-06-19 15:01:30 +02:00
|
|
|
|
|
|
|
|
|
|
|
class Observable():
    """Dummy observable that always emits the constant observation ``[0]``.

    The other observables in this file derive from this class and override
    the methods below.
    """

    def __init__(self):
        # Most recent observation; stays None until one has been produced.
        self.obs = None

    def _set_env(self, env):
        """Store the environment this observable reads its data from."""
        self.env = env

    def get_observation_space(self):
        """Return the one-element Box space of the dummy observation.

        Prints a warning, because an env using this class outputs nothing
        useful.
        """
        print("[!] Using dummyObservable. Env won't output anything")
        dummy_space = spaces.Box(low=0, high=1, shape=(1,), dtype=np.float64)
        return dummy_space

    def get_observation(self):
        """Return the constant observation as a one-element array."""
        constant = [0]
        return np.array(constant)

    def draw(self):
        """Draw nothing; there is nothing to visualise."""
        return None

    def reset(self):
        """Do nothing; there is no cached state to drop."""
        return None
|
|
|
|
|
2022-06-19 15:01:30 +02:00
|
|
|
|
|
|
|
class CnnObservable(Observable):
    """Image observation: a window of the rendered env around the agent.

    A rectangle of ``in_width`` x ``in_height`` pixels centred on the agent is
    cut from the env surface, padded where it leaves the surface, and scaled
    to ``out_width`` x ``out_height``. ``draw`` shows a preview of size
    ``draw_width`` x ``draw_height`` in the top-right corner of the screen.

    Currently broken...
    """

    def __init__(self, in_width=256, in_height=256, out_width=32, out_height=32, draw_width=128, draw_height=128, smooth_scaling=True):
        super().__init__()
        self.in_width = in_width
        self.in_height = in_height
        self.out_width = out_width
        self.out_height = out_height
        self.draw_width = draw_width
        self.draw_height = draw_height
        # Function used to shrink the snapshot to the output resolution.
        if smooth_scaling:
            self.scaler = pygame.transform.smoothscale
        else:
            self.scaler = pygame.transform.scale

    def get_observation_space(self):
        """Return the Box space for the scaled RGB image."""
        # NOTE(review): the space declares float64 while get_observation()
        # returns whatever dtype pygame.surfarray.array3d produces
        # (presumably an integer type) -- confirm which one is intended.
        return spaces.Box(low=0, high=255,
                          shape=(self.out_width, self.out_height, 3), dtype=np.float64)

    def get_observation(self):
        """Render the env if necessary and return the scaled window as array.

        The scaled pygame surface is also cached in ``self.obs`` for draw().
        """
        env = self.env
        if not env._rendered:
            env.render(mode='internal', dont_show=False)
        env._ensure_surface()

        # Top-left corner (in pixels) of the window centred on the agent.
        x = env.agent.pos[0]*env.width - self.in_width/2
        y = env.agent.pos[1]*env.height - self.in_height/2
        w, h = self.in_width, self.in_height

        # Clamp the window to the part that actually lies on the surface.
        cx, cy = _clip(x, 0, env.width), _clip(y, 0, env.height)
        cw, ch = _clip(w, 0, env.width - cx), _clip(h, 0, env.height - cy)
        rect = pygame.Rect(cx, cy, cw, ch)
        snap = env.surface.subsurface(rect)

        # Full-size canvas; areas outside the env keep the fill colour.
        self.snap = pygame.Surface((self.in_width, self.in_height))
        if env.void_barrier:
            col = (255, 0, 0)
        else:
            col = (50, 50, 50)
        pygame.draw.rect(self.snap, col,
                         pygame.Rect(0, 0, self.in_width, self.in_height))
        # Shift the visible part so the agent stays in the canvas centre.
        self.snap.blit(snap, (cx - x, cy - y))

        self.obs = self.scaler(self.snap, (self.out_width, self.out_height))
        arr = pygame.surfarray.array3d(self.obs)
        return arr

    def draw(self):
        """Blit an enlarged preview of the observation onto the env screen."""
        # Explicit None check: only a missing observation should trigger a
        # fresh one, independent of how a surface evaluates as a boolean.
        if self.obs is None:
            self.get_observation()
        big = pygame.transform.scale(
            self.obs, (self.draw_width, self.draw_height))
        x, y = self.env.width - self.draw_width - 10, 10
        # One pixel wide frame around the preview.
        pygame.draw.rect(self.env.screen, (50, 50, 50),
                         pygame.Rect(x - 1, y - 1, self.draw_width + 2, self.draw_height + 2))
        self.env.screen.blit(big, (x, y))
|
|
|
|
|
|
|
|
|
|
|
|
def _clip(num, lower, upper):
|
|
|
|
return min(max(num, lower), upper)
|
2022-06-19 20:33:45 +02:00
|
|
|
|
|
|
|
|
|
|
|
class RayObservable(Observable):
    """Observation made of rays cast outward from the agent.

    ``num_rays`` rays of length ``ray_len`` (pixels) are sampled at
    ``num_steps`` points each. There is one channel per entity type in
    ``chans``. A ray/channel cell holds ``(num_steps - s) / num_steps`` for
    the first sample index ``s`` that hits an entity of that type, and 0 if
    nothing is hit. With ``include_rand`` an extra row of random numbers is
    appended.
    """

    def __init__(self, num_rays=16, chans=[entities.Enemy, entities.Reward], ray_len=256, num_steps=64, include_rand=False):
        super().__init__()
        self.num_rays = num_rays
        # Private copy, so the shared default list can never be modified
        # through an instance.
        self.chans = list(chans)
        self.num_chans = len(self.chans)
        self.ray_len = ray_len
        self.num_steps = num_steps  # max = 255
        self.occlusion = True  # previous channels block view onto later channels
        self.include_rand = include_rand
        # Result of the last get_observation(); None until it has run once.
        self.rays = None

    def get_observation_space(self):
        """Return the Box space of the ray matrix.

        The dtype is float64 because the rays are fractions in [0, 1]; an
        integer dtype cannot represent the produced observations.
        """
        return spaces.Box(low=0, high=1,
                          shape=(self.num_rays+self.include_rand, self.num_chans), dtype=np.float64)

    def _get_ray_heads(self):
        """Yield the end point of every ray as an offset from the agent."""
        for i in range(self.num_rays):
            rad = 2*math.pi/self.num_rays*i
            yield self.ray_len*math.sin(rad), self.ray_len*math.cos(rad)

    def _check_collision(self, pos, entity_type, entities_l):
        """Return True if the pixel position ``pos`` hits a matching entity.

        An entity matches if it is an instance of ``entity_type``. When the
        env has ``void_barrier`` set, the void also counts for the Enemy
        channel.

        Raises:
            Exception: if a matching entity is neither circular nor
                rectangular.
        """
        env = self.env
        for entity in entities_l:
            is_void = isinstance(entity, entities.Void)
            matches = isinstance(entity, entity_type) or (
                env.void_barrier and is_void and entity_type == entities.Enemy)
            if not matches:
                continue
            if is_void:
                # The void is everything outside the env rectangle. The y
                # coordinate has to be compared with the height, not the
                # width, or non-square envs get a wrong border.
                outside = (0 >= pos[0] or pos[0] >= env.width
                           or 0 >= pos[1] or pos[1] >= env.height)
                if not env.torus_topology and outside:
                    return True
            elif entity.shape == 'circle':
                sq_dist = (pos[0]-entity.pos[0]*env.width) ** 2 \
                    + (pos[1]-entity.pos[1]*env.height)**2
                if sq_dist < entity.radius**2:
                    return True
            elif entity.shape == 'rect':
                # Reuse the entity collision code by probing with a tiny
                # circular entity placed at the sample position.
                dot = entities.CircularEntity(env)
                dot.radius = 1
                dot.pos = pos[0]/env.width, pos[1]/env.height
                # NOTE(review): a non-zero force whose components cancel
                # would sum to 0 here -- confirm this cannot happen.
                if sum(dot._get_crash_force_dir(entity)) != 0:
                    return True
            else:
                raise Exception(
                    'Can only raycast circular and rectangular entities!')
        return False

    # Filter out entities that we are sure are out of range
    # (so we have to do less work for the ray collisions)
    def _get_possible_entities(self):
        """Return the entities that may be within reach of any ray.

        Raises:
            Exception: if an entity is neither circular nor rectangular.
        """
        env = self.env
        entities_l = []
        if entities.Void in self.chans or env.void_barrier:
            entities_l.append(entities.Void(env))
        for entity in env.entities:
            if entity.shape == 'rect':
                # Centre of the rectangle in relative coordinates.
                x = entity.pos[0]+entity.width/env.width/2
                y = entity.pos[1]+entity.height/env.height/2
                radius = (entity.width/2 + entity.height/2)*1.0
            elif entity.shape == 'circle':
                x, y = entity.pos[0], entity.pos[1]
                radius = entity.radius
            else:
                raise Exception(
                    'Can only raycast circular and rectangular entities!')
            sq_dist = ((env.agent.pos[0]-x)*env.width) ** 2 \
                + ((env.agent.pos[1]-y)*env.height) ** 2
            if sq_dist <= (radius + env.agent.getQuasiRadius() + self.ray_len)**2:
                entities_l.append(entity)  # cannot use yield here!
        return entities_l

    # Ugly, inefficient ray casting
    # Oh well, it works...
    def get_observation(self):
        """Cast all rays and return the ray matrix (also kept in self.rays)."""
        env = self.env
        # Not named `entities`, which would shadow the imported module.
        candidates = self._get_possible_entities()
        self.rays = np.zeros((self.num_rays+self.include_rand, self.num_chans))
        if self.include_rand:
            for c in range(self.num_chans):
                self.rays[-1, c] = np.random.rand()
        for r, (hx, hy) in enumerate(self._get_ray_heads()):
            # Sample index of the nearest hit on an earlier channel; later
            # channels stop searching shortly behind it.
            occ_dist = self.num_steps
            for c, entity_type in enumerate(self.chans):
                for s in range(self.num_steps):
                    if s > occ_dist:
                        break
                    sx, sy = (s+1)*hx/self.num_steps, (s+1)*hy/self.num_steps
                    rx = sx + env.agent.pos[0]*env.width
                    ry = sy + env.agent.pos[1]*env.height
                    if env.torus_topology:
                        rx, ry = rx % env.width, ry % env.height
                    if self._check_collision((rx, ry), entity_type, candidates):
                        self.rays[r, c] = (self.num_steps-s)/self.num_steps
                        if self.occlusion:
                            occ_dist = s
                        break
        return self.rays

    def draw(self):
        """Mark the hit point of every ray and channel on the env screen."""
        # Without a previous observation there is nothing to draw from.
        if self.rays is None:
            self.get_observation()
        env = self.env
        for c, entity_type in enumerate(self.chans):
            for r, (hx, hy) in enumerate(self._get_ray_heads()):
                # Recover the sample index from the stored fraction.
                s = self.num_steps - self.rays[r, c]*self.num_steps
                sx, sy = (s+1)*hx/self.num_steps, (s+1)*hy/self.num_steps
                rx = sx + env.agent.pos[0]*env.width
                ry = sy + env.agent.pos[1]*env.height
                if env.torus_topology:
                    rx, ry = rx % env.width, ry % env.height
                # TODO: How stupid do I want to code?
                # This instantiates an object for every ray hit,
                # just to get the color for the visual.
                # But since this code will not be executed during training,
                # I don't think fixing this is a priority...
                col = entity_type(env).col
                col = int(col[0]/2), int(col[1]/2), int(col[2]/2)
                pygame.draw.circle(env.screen, col, (rx, ry), 3, width=0)
|
2022-06-19 22:46:42 +02:00
|
|
|
|
|
|
|
|
2022-06-20 23:11:11 +02:00
|
|
|
class StateObservable(Observable):
    """Observation built directly from entity state.

    The observation vector is laid out as: two coordinates per tracked
    entity, one ``active`` flag per timeout reward, the two speed components
    of the agent (if ``speedAgent``) and one random number (if
    ``include_rand``).

    Whitelists probably don't work...
    """

    def __init__(self, coordsAgent=False, speedAgent=False, coordsRelativeToAgent=True, coordsRewards=True, rewardsWhitelist=None, coordsEnemys=True, enemysWhitelist=None, enemysNoBarriers=True, rewardsTimeouts=True, include_rand=True):
        super().__init__()
        # Lazily built caches, see the `entities` property.
        self._entities = None
        self._timeoutEntities = []
        self.coordsAgent = coordsAgent
        self.speedAgent = speedAgent
        self.coordsRelativeToAgent = coordsRelativeToAgent
        self.coordRewards = coordsRewards
        self.rewardsWhitelist = rewardsWhitelist
        self.coordsEnemys = coordsEnemys
        self.enemysWhitelist = enemysWhitelist
        self.enemysNoBarriers = enemysNoBarriers
        self.rewardsTimeouts = rewardsTimeouts
        self.include_rand = include_rand

    @property
    def entities(self):
        """Return the tracked entities, building the cache on first access.

        Building the cache also rebuilds ``_timeoutEntities`` from scratch,
        so repeated rebuilds cannot accumulate duplicate entries.
        """
        if self._entities is not None:
            return self._entities
        # An unset (or empty) whitelist falls back to all env entities.
        rewardsWhitelist = self.rewardsWhitelist or self.env.entities
        enemysWhitelist = self.enemysWhitelist or self.env.entities
        tracked = []
        timeouts = []
        if self.coordsAgent:
            tracked.append(self.env.agent)
        if self.coordRewards:
            for entity in rewardsWhitelist:
                if isinstance(entity, entities.Reward):
                    tracked.append(entity)
        if self.coordsEnemys:
            for entity in enemysWhitelist:
                if isinstance(entity, entities.Enemy):
                    if not self.enemysNoBarriers or not isinstance(entity, entities.Barrier):
                        tracked.append(entity)
        if self.rewardsTimeouts:
            # Timeout rewards are rewards, so the rewards whitelist applies.
            for entity in rewardsWhitelist:
                if isinstance(entity, entities.TimeoutReward):
                    timeouts.append(entity)
        self._timeoutEntities = timeouts
        self._entities = tracked
        return self._entities

    def reset(self):
        """Drop the cached entity lists; they are rebuilt on next access."""
        self._entities = None
        self._timeoutEntities = []

    def get_observation_space(self):
        """Rebuild the entity caches and return the matching Box space."""
        self.reset()
        # Accessing the property first also fills _timeoutEntities.
        tracked = self.entities
        num = len(tracked)*2+len(self._timeoutEntities) + \
            self.speedAgent*2 + self.include_rand
        # Relative coordinates and speeds can be negative.
        return spaces.Box(low=0-1*(self.coordsRelativeToAgent or self.speedAgent), high=1,
                          shape=(num,), dtype=np.float64)

    def get_observation(self):
        """Return the observation vector (the raw list is kept in self.obs)."""
        agent = self.env.agent
        obs = []
        for entity in self.entities:
            # The agent itself is always reported in absolute coordinates.
            if self.coordsRelativeToAgent and not isinstance(entity, entities.Agent):
                obs.append(entity.pos[0] - agent.pos[0])
                obs.append(entity.pos[1] - agent.pos[1])
            else:
                obs.append(entity.pos[0])
                obs.append(entity.pos[1])
        for entity in self._timeoutEntities:
            obs.append(entity.active)
        if self.speedAgent:
            obs.append(agent.speed[0])
            obs.append(agent.speed[1])
        if self.include_rand:
            obs.append(np.random.rand())
        self.obs = obs
        return np.array(obs)

    def draw(self):
        """Mark the observed coordinates on the left and top screen edges."""
        if self.obs is None:
            self.get_observation()
        env = self.env
        # Relative coordinates are drawn around the middle of each edge.
        y_ofs = 0 + env.height/2*self.coordsRelativeToAgent
        x_ofs = 0 + env.width/2*self.coordsRelativeToAgent
        if self.coordsRelativeToAgent:
            pygame.draw.circle(env.screen, env.agent.col,
                               (0, env.height/2), 3, width=0)
            pygame.draw.circle(env.screen, env.agent.col,
                               (env.width/2, 0), 3, width=0)
        # The coordinate pairs are the first entries of the observation, one
        # pair per tracked entity. Iterating over the entities (instead of
        # deriving the count from the observation length) keeps the trailing
        # timeout, speed and random values from being read as coordinates.
        for i, entity in enumerate(self.entities):
            x, y = self.obs[i*2], self.obs[i*2+1]
            col = entity.col
            pygame.draw.circle(env.screen, col,
                               (0, y*env.height+y_ofs), 1, width=0)
            pygame.draw.circle(env.screen, col,
                               (x*env.width+x_ofs, 0), 1, width=0)
|
2022-08-07 18:03:27 +02:00
|
|
|
|
|
|
|
|
2022-08-20 21:32:34 +02:00
|
|
|
class CompassObservable(Observable):
    """Direction ("compass") towards each tracked entity.

    Useful for navigation close to a reward.
    Works like the StateObservable, but we assign a bigger range of possible
    input values to those that are close to zero.
    I found that agents without such an observable often moved close to a
    reward and then just jiggled around; adding a CompassObservable fixes
    this.
    """

    def __init__(self, coordsRewards=True, rewardsWhitelist=None, coordsEnemys=False, enemysWhitelist=None, enemysNoBarriers=True):
        super().__init__()
        # Lazily built cache, see the `entities` property.
        self._entities = None
        self._timeoutEntities = []
        self.coordRewards = coordsRewards
        self.rewardsWhitelist = rewardsWhitelist
        self.coordsEnemys = coordsEnemys
        self.enemysWhitelist = enemysWhitelist
        self.enemysNoBarriers = enemysNoBarriers

    @property
    def entities(self):
        """Return the tracked entities, building the cache on first access."""
        if self._entities is not None:
            return self._entities
        # An unset (or empty) whitelist falls back to all env entities.
        rewardsWhitelist = self.rewardsWhitelist or self.env.entities
        enemysWhitelist = self.enemysWhitelist or self.env.entities
        tracked = []
        if self.coordRewards:
            for entity in rewardsWhitelist:
                if isinstance(entity, entities.Reward):
                    tracked.append(entity)
        if self.coordsEnemys:
            for entity in enemysWhitelist:
                if isinstance(entity, entities.Enemy):
                    if not self.enemysNoBarriers or not isinstance(entity, entities.Barrier):
                        tracked.append(entity)
        self._entities = tracked
        return self._entities

    def get_observation_space(self):
        """Rebuild the entity cache and return the matching Box space."""
        self.reset()
        num = len(self.entities)*2
        return spaces.Box(low=-1, high=1,
                          shape=(num,), dtype=np.float64)

    def reset(self):
        """Drop the cached entity list; it is rebuilt on next access."""
        self._entities = None

    def get_observation(self):
        """Return two direction components per tracked entity.

        The offset to the entity is divided by twice its length and passed
        through tanh. An entity exactly at the agent position has no
        direction and is reported as (0, 0).
        """
        agent_pos = self.env.agent.pos
        obs = []
        for entity in self.entities:
            dx = entity.pos[0] - agent_pos[0]
            dy = entity.pos[1] - agent_pos[1]
            scale = math.sqrt(dx**2 + dy**2)*2
            if scale == 0:
                # Avoid the division by zero for a zero-length offset.
                x, y = 0.0, 0.0
            else:
                x, y = math.tanh(dx/scale), math.tanh(dy/scale)
            obs.append(x)
            obs.append(y)
        self.obs = obs
        return np.array(obs)

    def draw(self):
        """Mark the compass values on the left and top screen edges."""
        if self.obs is None:
            self.get_observation()
        env = self.env
        # Values are drawn around the middle of each edge.
        y_ofs = 0 + env.height/2
        x_ofs = 0 + env.width/2
        pygame.draw.circle(env.screen, env.agent.col,
                           (0, env.height/2), 3, width=0)
        pygame.draw.circle(env.screen, env.agent.col,
                           (env.width/2, 0), 3, width=0)
        for i in range(int(len(self.obs)/2)):
            x, y = self.obs[i*2], self.obs[i*2+1]
            col = self.entities[i].col
            pygame.draw.circle(env.screen, col,
                               (0, y*env.height+y_ofs), 1, width=0)
            pygame.draw.circle(env.screen, col,
                               (x*env.width+x_ofs, 0), 1, width=0)
|
|
|
|
|
|
|
|
|
2022-08-07 18:03:27 +02:00
|
|
|
class CompositionalObservable(Observable):
    """Combine several observables into one flat observation.

    Used whenever you want to attach multiple Observables to an Env.
    We currently flatten the outputs of all attached Observables, so using a
    CNN through a CompositionalObservable would lead to problems.
    """

    def __init__(self, observables):
        super().__init__()
        self.observables = observables

    def get_observation_space(self):
        """Return a flat Box space joining the bounds of all observables.

        An empty list of observables yields an empty (zero-length) space.
        """
        num = 0
        lows, highs = [], []
        for obs in self.observables:
            space = obs.get_observation_space()
            num += math.prod(space.shape)
            lows.append(space.low.reshape((-1)))
            highs.append(space.high.reshape((-1)))
        if lows:
            low, high = np.hstack(lows), np.hstack(highs)
        else:
            # Nothing attached: there are no bounds to join.
            low = np.zeros((0,), dtype=np.float64)
            high = np.zeros((0,), dtype=np.float64)
        return spaces.Box(low=low, high=high,
                          shape=(num,), dtype=np.float64)

    def get_observation(self):
        """Return the flattened observations of all observables, joined."""
        parts = [obs.get_observation().reshape((-1))
                 for obs in self.observables]
        if not parts:
            # np.hstack cannot join an empty list.
            return np.zeros((0,), dtype=np.float64)
        return np.hstack(parts)

    def draw(self):
        """Let every attached observable draw itself."""
        for obs in self.observables:
            obs.draw()

    def _set_env(self, env):
        """Store the env on this object and on every attached observable."""
        super()._set_env(env)
        for obs in self.observables:
            obs._set_env(env)

    def reset(self):
        """Reset every attached observable."""
        for obs in self.observables:
            obs.reset()
|