2022-06-20 23:10:14 +02:00
|
|
|
from gym.envs.registration import register
|
2022-06-19 15:01:30 +02:00
|
|
|
import gym
|
|
|
|
from gym import spaces
|
|
|
|
import numpy as np
|
|
|
|
import pygame
|
|
|
|
import random as random_dont_use
|
2022-06-21 15:13:59 +02:00
|
|
|
from os import urandom
|
2022-06-19 15:01:30 +02:00
|
|
|
import math
|
2022-07-16 23:25:48 +02:00
|
|
|
import torch as th
|
2022-06-19 15:01:30 +02:00
|
|
|
|
2022-12-09 11:20:15 +01:00
|
|
|
from columbus import entities, observables
|
|
|
|
from columbus.utils import soft_int, parseObs
|
2022-08-20 21:32:34 +02:00
|
|
|
|
|
|
|
|
2022-06-19 15:10:58 +02:00
|
|
|
class ColumbusEnv(gym.Env):
|
2022-12-13 20:13:53 +01:00
|
|
|
metadata = {'render.modes': ['human'], 'render_modes': [
|
|
|
|
'human', 'non-human'], 'render_fps': 60}
|
2022-06-19 15:01:30 +02:00
|
|
|
|
2022-12-09 17:05:06 +01:00
|
|
|
def __init__(self, observable=observables.Observable(), fps=60, env_seed=3.1, master_seed=None, start_pos=(0.5, 0.5), start_score=0, speed_fac=0.01, acc_fac=0.04, die_on_zero=False, return_on_score=-1, reward_mult=1, agent_drag=0, controll_type='SPEED', aux_reward_max=1, aux_penalty_max=0, aux_reward_discretize=0, void_is_type_barrier=True, void_damage=1, torus_topology=False, default_collision_elasticity=1, terminate_on_reward=False, agent_draw_path=False, clear_path_on_reset=True, max_steps=-1, value_color_mapper='tanh', width=720, height=720, agent_attrs={}, agent_cls=entities.Agent, exception_for_unsupported_collision=True):
    """Store the configuration and set up the random number generators.

    No entities are created here; that happens in ``reset()``.

    Selected arguments:
        observable: an ``observables.Observable``; anything else is passed
            through ``parseObs`` first.
        env_seed: seed handed to ``_seed`` here and on every ``reset()``;
            ``None`` makes ``_seed`` draw a seed from the master RNG.
        master_seed: seed of the master RNG. ``None`` uses ``os.urandom``,
            the string ``'numpy'`` draws a value from ``np.random``.
        speed_fac, acc_fac: rescaled by ``60 / fps`` before being stored.
        controll_type: must be ``'SPEED'`` or ``'ACC'``.
        value_color_mapper: ``'atan'`` or ``'tanh'`` select a built-in
            mapping; any other value is stored unchanged.
        agent_cls: agent class, or the name of an attribute of ``entities``.
        agent_attrs: attributes set on the agent after it is spawned.
    """
    # NOTE(review): the default ``observable`` instance is created once at
    # definition time and is bound to this env via ``_set_env`` below, so it
    # is shared by every env constructed without an explicit observable.
    super(ColumbusEnv, self).__init__()
    self.action_space = spaces.Box(
        low=-1, high=1, shape=(2,), dtype=np.float32)
    if not isinstance(observable, observables.Observable):
        observable = parseObs(observable)
    observable._set_env(self)
    self.observable = observable

    # Window / rendering state
    self.title = 'Columbus Env'
    self.fps = fps
    self.env_seed = env_seed
    self.joystick_offset = (10, 10)
    self.surface = None
    self.screen = None
    self.width = width
    self.height = height
    self.visible = False

    # Dynamics and reward configuration
    self.start_pos = start_pos
    self.speed_fac = speed_fac/fps*60
    self.acc_fac = acc_fac/fps*60
    self.die_on_zero = die_on_zero  # return (/die) when score hits zero
    self.return_on_score = return_on_score  # -1 = Never
    self.reward_mult = reward_mult
    self.start_score = start_score
    # 0.01 is a good value, drag with the environment (air / ground)
    self.agent_drag = agent_drag
    assert controll_type == 'SPEED' or controll_type == 'ACC'
    self.limit_inp_to_unit_circle = True
    self.controll_type = controll_type  # one of SPEED, ACC
    self.aux_reward_max = aux_reward_max  # 0 = off
    self.aux_penalty_max = aux_penalty_max  # 0 = off
    # 0 = don't discretize; how many steps (along diagonal)
    self.aux_reward_discretize = aux_reward_discretize
    # Don't change, only here to allow legacy behavior
    self.penalty_from_edges = True

    # Drawing toggles
    self.draw_observable = True
    self.draw_joystick = True
    self.draw_entities = True
    self.draw_confidence_ellipse = True

    # If the Void should be of type Barrier (else it is just of type Void and Entity)
    self.void_barrier = void_is_type_barrier
    self.void_damage = void_damage
    self.torus_topology = torus_topology
    self.default_collision_elasticity = default_collision_elasticity
    self.terminate_on_reward = terminate_on_reward
    self.agent_draw_path = agent_draw_path
    self.clear_path_on_reset = clear_path_on_reset
    self.path_decay = 0.1

    # The agent class may be given by name; resolve it against ``entities``.
    if isinstance(agent_cls, str):
        agent_cls = getattr(entities, agent_cls)
    self.Agent_cls = agent_cls
    self.agent_attrs = agent_attrs

    self.exception_for_unsupported_collision = exception_for_unsupported_collision

    # Named mappers are replaced by a function; anything else is kept as is.
    if value_color_mapper == 'atan':
        def value_color_mapper(x): return th.atan(x*2)/0.786/2
    elif value_color_mapper == 'tanh':
        def value_color_mapper(x): return th.tanh(x*2)/0.762/2
    self.value_color_mapper = value_color_mapper

    self.max_steps = max_steps
    self._steps = 0
    self._has_value_map = False

    self.paused = False
    self.keypress_timeout = 0
    self.can_accept_chol = True

    # The master RNG only hands out seeds for the per-episode RNG.
    self._master_rng = random_dont_use.Random()
    if master_seed is None:
        master_seed = urandom(12)
    elif master_seed == 'numpy':
        master_seed = np.random.rand()
    self._master_rng.seed(master_seed)
    self.rng = random_dont_use.Random()
    self._seed(self.env_seed)

    self._init = False
|
|
|
|
|
2022-06-22 20:32:17 +02:00
|
|
|
@property
def observation_space(self):
    """Observation space reported by the configured observable.

    The environment is reset first if that has not happened yet, so the
    observable is only queried after ``reset()`` has run at least once.
    """
    already_set_up = self._init
    if not already_set_up:
        self.reset()
    space = self.observable.get_observation_space()
    return space
|
2022-06-20 23:10:14 +02:00
|
|
|
|
2022-06-19 15:01:30 +02:00
|
|
|
def _seed(self, seed):
    """Seed the per-episode RNG ``self.rng``.

    Args:
        seed: value passed to ``self.rng.seed``. If it is ``None`` a fresh
            seed is drawn from ``self._master_rng`` instead.
    """
    # Identity check: ``== None`` would invoke a custom ``__eq__`` on the seed.
    if seed is None:
        seed = self._master_rng.random()
    self.rng.seed(seed)
|
|
|
|
|
|
|
|
def random(self):
    """Return the next float drawn from the environment RNG ``self.rng``."""
    draw = self.rng.random
    return draw()
|
|
|
|
|
|
|
|
def _ensure_surface(self):
    """Create the drawing surfaces and the screen if either is missing.

    Does nothing when both ``self.surface`` and ``self.screen`` are set.
    Otherwise the main surface and both overlays are recreated, and the
    screen becomes a display window when ``self.visible`` is true or an
    off-screen surface when it is not.
    """
    if self.surface and self.screen:
        return

    size = (self.width, self.height)
    self.surface = pygame.Surface(size)
    self.path_overlay = pygame.Surface(size, pygame.SRCALPHA, 32)
    self.value_overlay = pygame.Surface(size, pygame.SRCALPHA, 32)

    if not self.visible:
        self.screen = pygame.Surface(size)
        return

    self.screen = pygame.display.set_mode(size)
    pygame.display.set_caption(self.title)
|
2022-06-19 15:01:30 +02:00
|
|
|
|
|
|
|
def _limit_to_unit_circle(self, coords):
    """Scale ``coords`` back onto the unit circle if it lies outside.

    Args:
        coords: indexable pair ``(x, y)``.

    Returns:
        ``coords`` itself when its squared length is not greater than 1,
        otherwise a new tuple with the same direction and length 1.
    """
    x, y = coords[0], coords[1]
    length_sq = x**2 + y**2
    if not length_sq > 1:
        return coords
    length = math.sqrt(length_sq)
    return x / length, y / length
|
|
|
|
|
|
|
|
def _step_entities(self):
    """Advance every entity by one step, in list order."""
    # Iterates the list object bound to ``self.entities`` at loop start.
    for current in self.entities:
        current.step()
|
|
|
|
|
|
|
|
def _step_timers(self):
    """Count all timers down by one frame and fire the expired ones.

    Each timer is a ``(time_left, func, arg)`` tuple. ``time_left`` is
    reduced by ``1 / self.fps``; when it drops below zero ``func(arg)`` is
    called and the timer is discarded, otherwise the timer is kept with its
    reduced time.
    """
    still_running = []
    for remaining, callback, payload in self.timers:
        remaining = remaining - 1/self.fps
        if not remaining < 0:
            still_running.append((remaining, callback, payload))
        else:
            callback(payload)
    self.timers = still_running
|
|
|
|
|
2022-06-19 20:32:37 +02:00
|
|
|
def sq_dist(self, pos1, pos2):
    """Return the squared Euclidean distance between two 2D positions."""
    delta_x = pos1[0] - pos2[0]
    delta_y = pos1[1] - pos2[1]
    return delta_x**2 + delta_y**2
|
2022-06-19 15:01:30 +02:00
|
|
|
|
2022-06-19 20:32:37 +02:00
|
|
|
def dist(self, pos1, pos2):
    """Return the Euclidean distance between two 2D positions.

    Built on ``self.sq_dist`` so both stay consistent.
    """
    squared = self.sq_dist(pos1, pos2)
    return math.sqrt(squared)
|
2022-06-19 15:01:30 +02:00
|
|
|
|
|
|
|
def _get_aux_reward(self):
    """Return the distance-based auxiliary reward for the current state.

    Every available ``entities.Reward`` adds ``aux_reward_max / (1 + d²)``
    and every ``entities.Enemy`` with ``radiateDamage`` subtracts
    ``aux_penalty_max / (1 + d²)``, where ``d²`` is a squared distance to
    the agent. With ``penalty_from_edges`` the enemy distance is reduced by
    the (normalised, squared) radii for circles, or measured from the
    closest point of the rectangle for rects. If ``aux_reward_discretize``
    is non-zero each term is truncated to a multiple of
    ``1 / (2 * aux_reward_discretize)``.

    Returns:
        The summed auxiliary reward divided by ``self.fps``.

    Raises:
        Exception: with ``penalty_from_edges`` if the agent is not a circle
            or the enemy shape is neither ``'circle'`` nor ``'rect'``.
    """
    aux_reward = 0
    for entity in self.entities:
        if isinstance(entity, entities.Reward):
            if not entity.avaible:
                continue
            reward = self.aux_reward_max / \
                (1 + self.sq_dist(entity.pos, self.agent.pos))

            if self.aux_reward_discretize:
                reward = int(reward*self.aux_reward_discretize*2) / \
                    self.aux_reward_discretize / 2

            aux_reward += reward
        elif isinstance(entity, entities.Enemy):
            if not entity.radiateDamage:
                continue
            if self.penalty_from_edges:
                if self.agent.shape != 'circle':
                    raise Exception(
                        'Radiating damage from edge for non-circle Agents not supported')
                # Radii are in pixels; positions are normalised, so rescale.
                scale = max(self.height, self.width)
                if entity.shape == 'circle':
                    penalty = self.aux_penalty_max / \
                        (1 + self.sq_dist(entity.pos, self.agent.pos)
                         - (entity.radius/scale)**2
                         - (self.agent.radius/scale)**2)
                elif entity.shape == 'rect':
                    ax, ay = self.agent.pos
                    ex, ey = entity.pos[0], entity.pos[1]
                    ex2 = entity.pos[0] + entity.width / self.width
                    ey2 = entity.pos[1] + entity.height/self.height
                    # Foot of the perpendicular: the point of the rect
                    # closest to the agent centre.
                    lx, ly = ax, ay
                    if ax < ex:
                        lx = ex
                    elif ax > ex2:
                        lx = ex2
                    if ay < ey:
                        ly = ey
                    elif ay > ey2:
                        ly = ey2
                    penalty = self.aux_penalty_max / \
                        (1 + self.sq_dist((lx, ly), (ax, ay))
                         - (self.agent.radius/scale)**2)
                else:
                    # Previously ``penalty`` stayed unbound here (or kept
                    # the value computed for an earlier enemy).
                    raise Exception(
                        'Radiating damage from edge not supported for shape: '
                        + str(entity.shape))
            else:
                penalty = self.aux_penalty_max / \
                    (1 + self.sq_dist(entity.pos, self.agent.pos))

            if self.aux_reward_discretize:
                penalty = int(penalty*self.aux_reward_discretize*2) / \
                    self.aux_reward_discretize / 2

            aux_reward -= penalty
    return aux_reward/self.fps
|
2022-06-19 15:01:30 +02:00
|
|
|
|
|
|
|
def step(self, action):
    """Advance the simulation by one frame.

    Args:
        action: indexable pair with components in ``[-1, 1]``.

    Returns:
        ``(observation, reward * reward_mult, done, info)`` where ``info``
        holds the current score and the unscaled reward. When ``done`` is
        truthy the environment has already been reset, while the returned
        observation was taken before that reset.
    """
    if not self._init:
        self.reset()

    # Map the action from [-1, 1] to the internal [0, 1] joystick range.
    joystick = (action[0]+1)/2, (action[1]+1)/2
    if self._disturb_next:
        # A pending disturbance replaces the action exactly once.
        joystick, self._disturb_next = self._disturb_next, False
    if self.limit_inp_to_unit_circle:
        centered = ((joystick[0]-0.5)*2, (joystick[1]-0.5)*2)
        clipped = self._limit_to_unit_circle(centered)
        joystick = (clipped[0]+1)/2, (clipped[1]+1)/2
    self.inp = joystick

    if not self.paused:
        self._step_timers()
        self._step_entities()
    observation = self.observable.get_observation()

    gotRew = self.new_reward > 0 or self.new_abs_reward > 0
    self.gotHarm = self.new_reward < 0 or self.new_abs_reward < 0
    # new_reward is a per-second rate, new_abs_reward is taken as is.
    reward = self.new_reward / self.fps + self.new_abs_reward
    self.new_reward = 0
    self.new_abs_reward = 0

    if not self.torus_topology:
        agent = self.agent
        touches_void = (agent.getTop() < 1
                        or agent.getBottom() > self.height-1
                        or agent.getLeft() < 1
                        or agent.getRight() > self.width-1)
        if touches_void:
            reward -= self.void_damage/self.fps

    self.score += reward  # aux_reward does not count towards the score
    if self.aux_reward_max or self.aux_penalty_max:
        reward += self._get_aux_reward()
    self._steps += 1

    done = (
        (self.die_on_zero and self.score <= 0)
        or (self.return_on_score != -1 and self.score > self.return_on_score)
        or (self._steps == self.max_steps)
        or (self.terminate_on_reward and gotRew)
    )
    info = {'score': self.score, 'reward': reward}
    self._rendered = False
    if done:
        self.reset()
    return observation, reward*self.reward_mult, done, info
|
2022-06-19 15:01:30 +02:00
|
|
|
|
|
|
|
def check_collisions_for(self, entity):
    """Check ``entity`` against every other entity and report collisions.

    For each pair with a penetration depth greater than zero, both
    participants get their ``on_collision`` callback, ``entity`` first.
    """
    for candidate in self.entities:
        if not (candidate != entity):
            continue
        overlap = self._check_collision_between(entity, candidate)
        if overlap > 0:
            entity.on_collision(candidate, overlap)
            candidate.on_collision(entity, overlap)
|
2022-06-19 15:01:30 +02:00
|
|
|
|
2022-06-19 15:26:16 +02:00
|
|
|
def _check_collision_between(self, e1, e2):
    """Return the penetration depth between two entities (0 if apart).

    The pair is ordered by shape name first, so a circle always comes
    before a rect. Circle/circle depth is measured in pixels; circle/rect
    depth is the summed absolute crash-force components reported by the
    circle.

    Raises:
        Exception: for any other shape combination, unless
            ``exception_for_unsupported_collision`` is false, in which case
            0.0 is returned.
    """
    first, second = sorted([e1, e2], key=lambda ent: ent.shape)
    pair = [first.shape, second.shape]

    if pair == ['circle', 'circle']:
        dx = (first.pos[0]-second.pos[0])*self.width
        dy = (first.pos[1]-second.pos[1])*self.height
        centre_dist = math.sqrt(dx ** 2 + dy**2)
        return max(0, first.radius + second.radius - centre_dist)

    if pair == ['circle', 'rect']:
        return sum([abs(d) for d in first._get_crash_force_dir(second)])

    if self.exception_for_unsupported_collision:
        raise Exception(
            'Checking for collision between unsupported shapes: '+str(pair))
    return 0.0
|
2022-06-19 15:26:16 +02:00
|
|
|
|
2022-06-19 15:01:30 +02:00
|
|
|
def kill_entity(self, target):
    """Remove ``target`` from the list of entities.

    The previous implementation stopped copying as soon as it found the
    target, which also dropped every entity listed after it.

    A new list is bound to ``self.entities`` instead of mutating the old
    one, so a loop that is currently iterating the old list is unaffected.
    """
    self.entities = [entity for entity in self.entities if target != entity]
|
|
|
|
|
|
|
|
def setup(self):
|
2022-06-19 15:48:51 +02:00
|
|
|
self.agent.pos = self.start_pos
|
2022-06-20 23:10:14 +02:00
|
|
|
# Expand this function
|
2022-06-19 15:01:30 +02:00
|
|
|
|
2022-12-09 11:20:15 +01:00
|
|
|
def _spawnAgent(self):
    """Create the agent from ``self.Agent_cls`` and apply its settings.

    The agent gets ``draw_path`` from ``self.agent_draw_path`` and every
    key/value pair of ``self.agent_attrs`` as an attribute.
    """
    agent = self.Agent_cls(self)
    self.agent = agent
    agent.draw_path = self.agent_draw_path
    for attr_name, attr_value in self.agent_attrs.items():
        setattr(agent, attr_name, attr_value)
|
|
|
|
|
2022-10-16 17:51:05 +02:00
|
|
|
def reset(self, force_reset_path=False):
    """Start a new episode and return its first observation.

    Re-seeds the RNG from ``env_seed``, clears all per-episode state,
    spawns the agent, calls ``setup()`` and resets the observable. The path
    overlay is cleared when ``clear_path_on_reset`` or ``force_reset_path``
    is set.
    """
    pygame.init()
    self._init = True
    self._seed(self.env_seed)

    # Per-episode bookkeeping
    self._steps = 0
    self._has_value_map = False
    self._rendered = False
    self._disturb_next = False
    self.gotHarm = False
    self.inp = (0.5, 0.5)
    self.score = self.start_score

    # new_reward will get rescaled according to fps (= reward per second);
    # new_abs_reward will not get rescaled and should be used for one-time
    # rewards.
    self.new_reward = 0
    self.new_abs_reward = 0

    # World contents
    self.entities = []
    self.timers = []
    self._spawnAgent()
    self.setup()
    self.entities.append(self.agent)  # add it last, will be drawn on top
    self.observable.reset()

    if self.clear_path_on_reset or force_reset_path:
        self._reset_paths()

    first_observation = self.observable.get_observation()
    return first_observation
|
2022-06-19 15:01:30 +02:00
|
|
|
|
2022-09-23 22:22:51 +02:00
|
|
|
def _reset_paths(self):
    """Replace the path overlay with a fresh, empty alpha surface."""
    overlay_size = (self.width, self.height)
    self.path_overlay = pygame.Surface(overlay_size, pygame.SRCALPHA, 32)
|
|
|
|
|
2022-06-19 15:01:30 +02:00
|
|
|
def _draw_entities(self):
    """Draw every entity, in list order (later entries end up on top)."""
    for drawable in self.entities:
        drawable.draw()
|
|
|
|
|
2022-10-24 10:08:14 +02:00
|
|
|
def _invalidate_value_map(self):
    """Mark the cached value map as stale.

    The next ``_draw_values`` call will then recompute it even when asked
    for a static map.
    """
    self._has_value_map = False
|
|
|
|
|
|
|
|
def _draw_values(self, value_func, static=True, resolution=64, color_depth=224, color_mapper=None):
    """Draw a colour map of ``value_func`` over the whole playing field.

    The agent is moved (with zero speed) to the centre of every cell of a
    ``resolution`` x ``resolution`` grid, the observation is recorded, and
    all observations are passed to ``value_func`` in one batch. The values
    are normalised by twice their largest magnitude, optionally passed
    through ``color_mapper``, shifted by 0.5 and drawn as a red-to-green
    overlay.

    Args:
        value_func: callable taking a tensor of stacked observations and
            returning one value per observation.
        static: if true, a previously computed overlay is reused.
        resolution: number of grid cells per axis.
        color_depth: maximum colour channel / alpha value used.
        color_mapper: optional callable applied to the normalised values.
    """
    if not (static and self._has_value_map):
        saved_pos = self.agent.pos
        saved_speed = self.agent.speed
        self.agent.speed = (0, 0)
        self.value_overlay = pygame.Surface(
            (self.width, self.height), pygame.SRCALPHA, 32)
        obs = []
        try:
            for i in range(resolution):
                for j in range(resolution):
                    self.agent.pos = (i+0.5)/resolution, (j+0.5)/resolution
                    obs.append(self.observable.get_observation())
        finally:
            # Always put the agent back, even if an observation failed.
            self.agent.pos = saved_pos
            self.agent.speed = saved_speed

        V = value_func(th.Tensor(np.array(obs)))
        scale = max(V.max(), -1*V.min())*2
        # All-zero values would give 0/0 = NaN and crash in int() below;
        # leave them at zero instead. Dividing out-of-place also avoids
        # modifying the tensor owned by value_func.
        if float(scale) != 0:
            V = V / scale
        if color_mapper is not None:
            V = color_mapper(V)
        V = V + 0.5

        cell_w = self.width/resolution
        cell_h = self.height/resolution
        c = 0
        for i in range(resolution):
            for j in range(resolution):
                v = V[c].item()
                c += 1
                col = [int((1-v)*color_depth),
                       int(v*color_depth), 0, color_depth]
                # +1 so neighbouring cells overlap instead of leaving gaps.
                rect = pygame.Rect(i*cell_w, j*cell_h,
                                   int(cell_w)+1, int(cell_h)+1)
                pygame.draw.rect(self.value_overlay, col, rect, width=0)
    self.surface.blit(self.value_overlay, (0, 0))
    self._has_value_map = True
|
2022-10-15 20:54:16 +02:00
|
|
|
|
2022-06-19 15:01:30 +02:00
|
|
|
def _draw_observable(self, forceDraw=False):
    """Let the observable draw itself, if enabled.

    Drawing needs ``self.draw_observable`` to be set and either a visible
    window or ``forceDraw``.
    """
    if not self.draw_observable:
        return
    if self.visible or forceDraw:
        self.observable.draw()
|
|
|
|
|
|
|
|
def _draw_joystick(self, forceDraw=False):
    """Draw the joystick widget showing the current input ``self.inp``.

    An outlined circle marks the joystick range and a filled circle marks
    the current input; the filled circle turns white while a disturbance
    is pending. Drawing needs ``self.draw_joystick`` to be set and either a
    visible window or ``forceDraw``.
    """
    if not (self.draw_joystick and (self.visible or forceDraw)):
        return

    x, y = self.inp
    off_x = self.joystick_offset[0]
    off_y = self.joystick_offset[1]

    bigcol = (100, 100, 100)
    smolcol = (255, 255, 255) if self._disturb_next else (100, 100, 100)

    frame_center = (50 + off_x, 50+off_y)
    knob_center = (20+int(60*x) + off_x, 20+int(60*y)+off_y)
    pygame.draw.circle(self.screen, bigcol, frame_center, 50, width=1)
    pygame.draw.circle(self.screen, smolcol, knob_center, 20, width=0)
|
|
|
|
|
2022-11-01 16:25:47 +01:00
|
|
|
def _draw_confidence_ellipse(self, chol, forceDraw=False, seconds=0.1):
    """Draw an ellipse around the agent derived from ``chol``.

    ``chol`` is reduced to a 2x2 matrix (leading dimensions are dropped,
    a vector is expanded to a diagonal matrix), ``chol.T @ chol`` is
    eigen-decomposed, and the square roots of the eigenvalues give the
    axis lengths of the ellipse.

    The 'seconds' parameter only really makes sense when using
    control_type='SPEED'. It can still be used to scale the ellipse when
    using control_type='ACC', but its relation to seconds is then lost.

    Drawing needs ``self.draw_confidence_ellipse`` to be set and either a
    visible window or ``forceDraw``.
    """
    if not (self.draw_confidence_ellipse and (self.visible or forceDraw)):
        return

    col = (255, 255, 255)
    f = seconds*self.speed_fac*self.fps*max(self.height, self.width)

    # Reduce the input to a single 2x2 matrix.
    while len(chol.shape) > 2:
        chol = chol[0]
    if chol.shape != (2, 2):
        chol = th.diag_embed(chol)
    if len(chol.shape) != 2:
        chol = chol[0]
    cov = chol.T @ chol

    eigvals, eigvecs = th.linalg.eig(cov)
    eigvals, eigvecs = eigvals.real, eigvecs.real
    l1 = int(abs(math.sqrt(eigvals[0].item())*f)) + 1
    l2 = int(abs(math.sqrt(eigvals[1].item())*f))+1

    # The longer axis becomes the width and decides the rotation.
    major = 0 if l1 >= l2 else 1
    w, h = (l1, l2) if major == 0 else (l2, l1)
    run, rise = eigvecs[major][0], eigvecs[major][1]

    ang = (math.atan(rise/run))/(2*math.pi)*360

    x, y = self.agent.pos
    x, y = x*self.width, y*self.height
    rect = pygame.Rect((x-w/2, y-h/2, w, h))
    shape_surface = pygame.Surface(rect.size, pygame.SRCALPHA)
    pygame.draw.ellipse(shape_surface, col, (0, 0, *rect.size), 1)
    rotated_surf = pygame.transform.rotate(shape_surface, ang)
    target_rect = rotated_surf.get_rect(center=rect.center)
    self.screen.blit(rotated_surf, target_rect)
|
2022-07-16 23:25:48 +02:00
|
|
|
|
2022-12-09 11:20:15 +01:00
|
|
|
def _draw_paths(self):
    """Fade the path overlay a little and draw it onto the main surface.

    With a non-zero ``path_decay`` a black surface with alpha
    ``soft_int(255 * path_decay / fps)`` is blitted over the overlay first.
    """
    if self.path_decay != 0.0:
        fade = pygame.Surface((self.width, self.height))
        fade_alpha = soft_int(255*self.path_decay/self.fps)
        fade.set_alpha(fade_alpha)
        fade.fill((0, 0, 0))
        self.path_overlay.blit(fade, (0, 0))
    self.surface.blit(self.path_overlay, (0, 0))
|
|
|
|
|
2022-06-22 20:32:17 +02:00
|
|
|
def _handle_user_input(self):
    """Process keyboard input for the interactive window.

    Toggle keys (only handled while ``keypress_timeout`` is zero, which
    then restarts the timeout at ``int(fps / 5)`` frames):
        m: draw entities, c: draw confidence ellipse, r: reset,
        t: clear paths, p: pause.
    Keys that act on every frame while held down:
        q: random disturbance, w/a/s/d: disturbance in a fixed direction.
    """
    # Fetch and discard all pending events.
    pygame.event.get()
    pressed = pygame.key.get_pressed()

    if self.keypress_timeout != 0:
        self.keypress_timeout -= 1
    else:
        self.keypress_timeout = int(self.fps/5)
        if pressed[pygame.K_m]:
            self.draw_entities = not self.draw_entities
        elif pressed[pygame.K_c]:
            self.draw_confidence_ellipse = not self.draw_confidence_ellipse
        elif pressed[pygame.K_r]:
            self.reset()
        elif pressed[pygame.K_t]:
            self._reset_paths()
        elif pressed[pygame.K_p]:
            self.paused = not self.paused
        else:
            # No toggle key was pressed: do not start a timeout.
            self.keypress_timeout = 0

    # keys that can be held down to continuously trigger them
    if pressed[pygame.K_q]:
        self._disturb_next = (
            random_dont_use.random(), random_dont_use.random())
        return
    fixed_pushes = (
        (pygame.K_w, (0.5, 0.0)),
        (pygame.K_a, (0.0, 0.5)),
        (pygame.K_s, (0.5, 1.0)),
        (pygame.K_d, (1.0, 0.5)),
    )
    for key, push in fixed_pushes:
        if pressed[key]:
            self._disturb_next = push
            break
|
|
|
|
|
2022-10-15 20:54:16 +02:00
|
|
|
def render(self, mode='human', dont_show=False, chol=None, value_func=None, values_static=True):
    """Render the current state.

    Args:
        mode: ``'human'`` also handles keyboard input and updates the
            window; any other mode returns the frame as an array.
        dont_show: if true in ``'human'`` mode, only the internal surface
            is drawn and nothing is shown.
        chol: optional matrix for ``_draw_confidence_ellipse``.
        value_func: optional callable for ``_draw_values``.
        values_static: passed to ``_draw_values`` as ``static``.

    Returns:
        ``None`` in ``'human'`` mode, otherwise the result of
        ``pygame.surfarray.array3d`` on the screen.
    """
    human = mode == 'human'
    if human:
        self._handle_user_input()
    self.visible = self.visible or not dont_show
    self._ensure_surface()

    # Clear the frame.
    pygame.draw.rect(self.surface, (0, 0, 0),
                     pygame.Rect(0, 0, self.width, self.height))
    if value_func is not None:
        self._draw_values(value_func, values_static,
                          color_mapper=self.value_color_mapper)
    self._draw_paths()
    if self.draw_entities:
        self._draw_entities()
    else:
        self.agent.draw()
    self._rendered = True

    if human and dont_show:
        return

    self.screen.blit(self.surface, (0, 0))
    self._draw_observable(forceDraw=not human)
    self._draw_joystick(forceDraw=not human)
    # Identity check: ``chol != None`` compares elementwise for array-like
    # inputs and can raise when used as a condition.
    if chol is not None:
        self._draw_confidence_ellipse(chol, forceDraw=not human)
    if self.visible and human:
        pygame.display.update()
    if not human:
        return pygame.surfarray.array3d(self.screen)
|
2022-06-19 15:01:30 +02:00
|
|
|
|
|
|
|
def close(self):
    """Shut down the pygame display module, then pygame itself."""
    pygame.display.quit()
    pygame.quit()
|
2022-06-19 15:59:23 +02:00
|
|
|
|
|
|
|
|
2022-12-06 12:16:42 +01:00
|
|
|
class ColumbusConfigDefined(ColumbusEnv):
|
|
|
|
# Allows defining Columbus Environments using dicts.
|
|
|
|
# Intended to be used in combination with cw2 configuration.
|
|
|
|
# Look into humanPlayer to see how this is supposed to be interfaced with.
|
|
|
|
|
|
|
|
def __init__(self, observable={}, env_seed=None, entities=[], fps=30, **kw):
    """Create an environment from plain config values.

    Args:
        observable: passed on to ``ColumbusEnv``.
        env_seed: passed on to ``ColumbusEnv``.
        entities: list of entity definitions (dicts) consumed by
            ``setup()``. Note that this parameter shadows the ``entities``
            module inside this method.
        fps: passed on to ``ColumbusEnv``.
        **kw: any further ``ColumbusEnv`` keyword arguments.

    The start position may use unit strings; it is converted to 'em' here.
    """
    super().__init__(
        observable=observable, fps=fps, env_seed=env_seed, **kw)
    self.entities_definitions = entities
    start_x = self.conv_unit(self.start_pos[0], target='em', axis='x')
    start_y = self.conv_unit(self.start_pos[1], target='em', axis='y')
    self.start_pos = start_x, start_y
|
2022-12-06 12:16:42 +01:00
|
|
|
|
|
|
|
def is_unit(self, s):
    """Return whether ``s`` is a value that ``conv_unit`` understands.

    Accepted are ints and floats, numeric strings (``'5'``, ``'-0.5'``) and
    numeric strings followed by one of the two-letter units
    ``px, em, rx, ry, ct, au`` (``'10px'``). Anything else, including
    non-string objects, gives ``False``.
    """
    def looks_numeric(text):
        # One optional leading minus, at most one dot, otherwise digits.
        # (Replacing '-' anywhere used to accept '1-2', '-' and '-px'.)
        unsigned = text[1:] if text.startswith('-') else text
        return unsigned.replace('.', '', 1).isdecimal()

    if type(s) in [int, float]:
        return True
    if not isinstance(s, str):
        return False
    if looks_numeric(s):
        return True
    num, unit = s[:-2], s[-2:]
    if unit in ['px', 'em', 'rx', 'ry', 'ct', 'au']:
        return looks_numeric(num)
    return False
|
|
|
|
|
|
|
|
def conv_unit(self, s, target='px', axis='x'):
    """Convert a unit value to ``target`` units.

    Args:
        s: int/float (returned unchanged), a unitless numeric string, or a
            numeric string with a two-letter unit suffix:
            ``em`` (fraction of the axis), ``px`` (pixels), ``rx``/``ry``
            (pixels, forcing the x/y axis), ``au`` (multiples of 36
            pixels) or ``ct`` (percent).
        target: ``'em'`` or ``'px'``.
        axis: ``'x'`` or ``'y'``; selects width or height as the extent.

    Returns:
        The number itself for int/float input; for unitless strings an int
        (target ``'px'``, truncated) or float; otherwise the converted
        value, truncated to int for target ``'px'``.

    Raises:
        Exception: for an unknown unit or target.
    """
    assert self.is_unit(s)
    if type(s) in [int, float]:
        return s

    # Unitless numeric string, with an optional leading minus.
    unsigned = s[1:] if s.startswith('-') else s
    if unsigned.replace('.', '', 1).isdecimal():
        if target == 'px':
            # Via float so that strings such as '5.5' do not crash int().
            return int(float(s))
        return float(s)

    num, unit = s[:-2], s[-2:]
    num = float(num)
    extent = {'x': self.width, 'y': self.height}

    # 'rx' / 'ry' are pixel values tied to a fixed axis.
    if unit == 'rx':
        unit = 'px'
        axis = 'x'
    elif unit == 'ry':
        unit = 'px'
        axis = 'y'

    if unit == 'em':
        em = num
    elif unit == 'px':
        em = num / extent[axis]
    elif unit == 'au':
        em = num * 36 / extent[axis]
    elif unit == 'ct':
        em = num / 100
    else:
        raise Exception('Conversion not implemented')

    if target == 'em':
        return em
    elif target == 'px':
        return int(em * extent[axis])
    # Previously an unknown target silently returned None.
    raise Exception('Conversion target not implemented: ' + str(target))
|
|
|
|
|
|
|
|
def setup(self):
    """Populate the environment from ``self.entities_definitions``.

    Each definition is a dict with:
        type: name of a class in the ``entities`` module.
        num: how many to create (default 1).
        num_rand: adds ``int(random * (0.99 + num_rand))`` extra entities.
    Every other key is applied to each created entity:
        pos: pair converted to 'em' via ``conv_unit``.
        width / height / radius: converted to pixels via ``conv_unit``.
        <attr>_rand: int or float, adds a random amount in ``[0, value]``
            (int) or ``[0, value)`` (float) to ``<attr>``; for a list,
            element ``i`` of ``<attr>`` is set to such a random amount.
        <attr>_randf: adds ``value * random`` to ``<attr>``.
        anything else: set as an attribute unchanged.
    """
    self.agent.pos = self.start_pos
    reserved = ['num', 'num_rand', 'type']
    for definition in self.entities_definitions:
        Entity = getattr(entities, definition['type'])
        count = definition.get('num', 1) + \
            int(self.random()*(0.99+definition.get('num_rand', 0)))
        conf = {k: v for k, v in definition.items()
                if str(k) not in reserved}
        for _ in range(count):
            entity = Entity(self)

            for k, v_raw in conf.items():
                if k == 'pos':
                    v = self.conv_unit(v_raw[0], target='em', axis='x'), self.conv_unit(
                        v_raw[1], target='em', axis='y')
                elif k in ['width', 'height', 'radius']:
                    v = self.conv_unit(
                        v_raw, target='px', axis='y' if k == 'height' else 'x')
                else:
                    v = v_raw

                if k.endswith('_rand'):
                    # The attribute name is the key without the suffix.
                    # (The float case used to strip '_randf', which never
                    # matches a key ending in '_rand'.)
                    n = k.replace('_rand', '')
                    if isinstance(v, int):
                        cur = getattr(entity, n)
                        inc = int((v+0.99)*self.random())
                        setattr(entity, n, cur + inc)
                    elif isinstance(v, float):
                        cur = getattr(entity, n)
                        inc = v*self.random()
                        setattr(entity, n, cur + inc)
                    elif isinstance(v, list):
                        # Per-element randomisation; the checks must look
                        # at the element ``ve``, not at the list ``v``.
                        # NOTE(review): elements are assigned, not added
                        # to, as the original statements were written;
                        # confirm this is intended. Requires the attribute
                        # to support item assignment.
                        cur = getattr(entity, n)
                        for vi, ve in enumerate(v):
                            if isinstance(ve, int):
                                cur[vi] = int((ve+0.99)*self.random())
                            elif isinstance(ve, float):
                                cur[vi] = ve*self.random()
                        setattr(entity, n, cur)
                elif k.endswith('_randf'):
                    n = k.replace('_randf', '')
                    cur = getattr(entity, n)
                    inc = v*self.random()
                    setattr(entity, n, cur + inc)
                else:
                    setattr(entity, k, v)

            self.entities.append(entity)
|
|
|
|
|
|
|
|
###
|
|
|
|
# Custom Env Definitions
|
|
|
|
|
|
|
|
|
2022-06-19 15:59:23 +02:00
|
|
|
class ColumbusTest3_1(ColumbusEnv):
    """Test level (env seed 3.1) with circle barriers, flying chasers and
    one teleporting reward, observed through a 48x48 CNN observable by
    default."""

    def __init__(self, observable=observables.CnnObservable(out_width=48, out_height=48), fps=30, aux_reward_max=1, **kw):
        super().__init__(
            observable=observable, fps=fps, env_seed=3.1,
            aux_reward_max=aux_reward_max, **kw)
        self.start_pos = [0.6, 0.3]
        self.score = 0

    def setup(self):
        """Place the agent and create the entities of this level."""
        self.agent.pos = self.start_pos

        # 18 barriers with a radius between 50 and 90
        for _ in range(18):
            barrier = entities.CircleBarrier(self)
            barrier.radius = self.random()*40+50
            self.entities.append(barrier)

        # 3 chasers with a random acceleration
        for _ in range(3):
            chaser = entities.FlyingChaser(self)
            chaser.chase_acc = self.random()*0.4*0.3  # *0.6+0.5
            self.entities.append(chaser)

        # timeout rewards are currently disabled (count of 0)
        for _ in range(0):
            timed = entities.TimeoutReward(self)
            self.entities.append(timed)

        for _ in range(1):
            goal = entities.TeleportingReward(self)
            self.entities.append(goal)
|
2022-06-19 20:32:37 +02:00
|
|
|
|
2022-06-20 23:10:14 +02:00
|
|
|
|
2022-09-13 22:14:17 +02:00
|
|
|
class ColumbusTestRect(ColumbusEnv):
    """Small test level (env seed 3.3, 'ACC' control) with one rect
    barrier, one circle barrier and one teleporting reward."""

    def __init__(self, observable=observables.RayObservable(), fps=30, aux_reward_max=1, **kw):
        super().__init__(
            observable=observable, fps=fps, env_seed=3.3,
            aux_reward_max=aux_reward_max, controll_type='ACC', **kw)
        self.start_pos = [0.5, 0.5]
        self.score = 0

    def setup(self):
        """Place the agent and create the entities of this level."""
        self.agent.pos = self.start_pos

        for _ in range(1):
            wall = entities.RectBarrier(self)
            wall.width = self.random()*40+50
            wall.height = self.random()*40+50
            self.entities.append(wall)

        for _ in range(1):
            pillar = entities.CircleBarrier(self)
            pillar.radius = self.random()*40+50
            self.entities.append(pillar)

        for _ in range(1):
            goal = entities.TeleportingReward(self)
            self.entities.append(goal)
|
|
|
|
|
|
|
|
|
2022-06-20 23:10:14 +02:00
|
|
|
class ColumbusTestRay(ColumbusTest3_1):
    """``ColumbusTest3_1`` with a ray observable by default.

    ``hide_map=True`` turns off drawing of the entities.
    """

    def __init__(self, observable=observables.RayObservable(), hide_map=False, fps=30, **kw):
        super().__init__(observable=observable, fps=fps, **kw)
        show_map = not hide_map
        self.draw_entities = show_map
|
|
|
|
|
|
|
|
|
|
|
|
class ColumbusRayDrone(ColumbusTestRay):
    """``ColumbusTestRay`` switched to 'ACC' control with an agent drag of
    0.02 (both set after the base class has been initialised)."""

    def __init__(self, observable=observables.RayObservable(), hide_map=False, fps=30, **kw):
        super().__init__(
            observable=observable, hide_map=hide_map, fps=fps, **kw)
        self.controll_type, self.agent_drag = 'ACC', 0.02
|
|
|
|
|
|
|
|
|
2022-09-13 22:14:17 +02:00
|
|
|
class ColumbusDemoEnv3_1(ColumbusEnv):
    """Demo level (env seed 3.1, 'ACC' control, drag 0.05) with circle
    barriers and one teleporting reward."""

    def __init__(self, observable=observables.Observable(), fps=30, aux_reward_max=1, **kw):
        super().__init__(
            observable=observable, fps=fps, env_seed=3.1,
            aux_reward_max=aux_reward_max, controll_type='ACC',
            agent_drag=0.05, **kw)
        self.start_pos = [0.6, 0.3]
        self.score = 0

    def setup(self):
        """Place the agent and create the entities of this level."""
        self.agent.pos = self.start_pos

        # 18 barriers with a radius between 50 and 90
        for _ in range(18):
            barrier = entities.CircleBarrier(self)
            barrier.radius = self.random()*40+50
            self.entities.append(barrier)

        # chasers are currently disabled (count of 0)
        for _ in range(0):
            chaser = entities.FlyingChaser(self)
            chaser.chase_acc = self.random()*0.4*0.3  # *0.6+0.5
            self.entities.append(chaser)

        for _ in range(1):
            goal = entities.TeleportingReward(self)
            self.entities.append(goal)
|
|
|
|
|
|
|
|
|
|
|
|
class ColumbusDemoEnv2_7(ColumbusEnv):
    """Demo env with 12 circular barriers, 3 flying chasers and one reward.

    Built with ``env_seed=2.7``, ``controll_type='ACC'`` and
    ``agent_drag=0.05``.
    """

    def __init__(self, observable=None, fps=30, aux_reward_max=1, **kw):
        """Create the environment.

        Args:
            observable: Observable used by the env. Defaults to a fresh
                ``observables.Observable()``.
            fps: Forwarded to ``ColumbusEnv``.
            aux_reward_max: Forwarded to ``ColumbusEnv``.
            **kw: Remaining keyword arguments, forwarded to ``ColumbusEnv``.
        """
        # Build the default per instance: ColumbusEnv binds the observable to
        # the env via _set_env, so a default created once in the signature
        # would be shared by (and re-bound between) all instances.
        if observable is None:
            observable = observables.Observable()
        super().__init__(
            observable=observable, fps=fps, env_seed=2.7,
            aux_reward_max=aux_reward_max, controll_type='ACC',
            agent_drag=0.05, **kw)
        self.start_pos = [0.6, 0.3]
        self.score = 0

    def setup(self):
        """Place the agent at ``start_pos`` and populate ``self.entities``."""
        self.agent.pos = self.start_pos
        for _ in range(12):
            barrier = entities.CircleBarrier(self)
            barrier.radius = self.random()*30+40
            self.entities.append(barrier)
        for _ in range(3):
            chaser = entities.FlyingChaser(self)
            chaser.chase_acc = self.random()*0.4*0.3  # *0.6+0.5
            self.entities.append(chaser)
        for _ in range(1):
            self.entities.append(entities.TeleportingReward(self))
|
|
|
|
|
|
|
|
|
|
|
|
class ColumbusDemoEnvFootball(ColumbusEnv):
    """Demo football env: 8 circular barriers, a ball, a goal and opponents.

    Built with a fixed ``env_seed`` of 1.23.
    """

    def __init__(self, observable=None, fps=30, walkingOpponent=0, flyingOpponent=0, **kw):
        """Create the environment.

        Args:
            observable: Observable used by the env. Defaults to a fresh
                ``observables.Observable()``.
            fps: Forwarded to ``ColumbusEnv``.
            walkingOpponent: Number of ``WalkingFootballPlayer`` entities
                added in ``setup``.
            flyingOpponent: Number of ``FlyingFootballPlayer`` entities added
                in ``setup``.
            **kw: Remaining keyword arguments, forwarded to ``ColumbusEnv``.
        """
        # Build the default per instance: ColumbusEnv binds the observable to
        # the env via _set_env, so a default created once in the signature
        # would be shared by (and re-bound between) all instances.
        if observable is None:
            observable = observables.Observable()
        super().__init__(
            observable=observable, fps=fps, env_seed=1.23, **kw)
        self.start_pos = [0.5, 0.5]
        self.score = 0
        self.walkingOpponents = walkingOpponent
        self.flyingOpponents = flyingOpponent

    def setup(self):
        """Place the agent at ``start_pos`` and populate ``self.entities``."""
        self.agent.pos = self.start_pos
        for _ in range(8):
            barrier = entities.CircleBarrier(self)
            barrier.radius = self.random()*40+50
            self.entities.append(barrier)
        # The ball is created first because the opponents are constructed
        # with a reference to it.
        ball = entities.Ball(self)
        self.entities.append(ball)
        self.entities.append(entities.TeleportingGoal(self))
        for _ in range(self.walkingOpponents):
            self.entities.append(entities.WalkingFootballPlayer(self, ball))
        for _ in range(self.flyingOpponents):
            self.entities.append(entities.FlyingFootballPlayer(self, ball))
|
|
|
|
|
|
|
|
|
2022-06-20 23:10:14 +02:00
|
|
|
class ColumbusCandyland(ColumbusEnv):
    """Env containing only two teleporting rewards (no barriers or enemies)."""

    def __init__(self, observable=None, hide_map=False, fps=30, env_seed=None, **kw):
        """Create the environment.

        Args:
            observable: Observable used by the env. Defaults to a fresh
                16-ray ``RayObservable`` with ``Reward`` and ``Void`` channels
                and ``include_rand=True``.
            hide_map: When true, ``draw_entities`` is set to ``False``.
            fps: Forwarded to ``ColumbusEnv``.
            env_seed: Forwarded to ``ColumbusEnv``.
            **kw: Remaining keyword arguments, forwarded to ``ColumbusEnv``.
        """
        # Build the default per instance: ColumbusEnv binds the observable to
        # the env via _set_env, so a default created once in the signature
        # would be shared by (and re-bound between) all instances.
        if observable is None:
            observable = observables.RayObservable(
                chans=[entities.Reward, entities.Void], num_rays=16,
                include_rand=True)
        super().__init__(
            observable=observable, fps=fps, env_seed=env_seed, **kw)
        self.draw_entities = not hide_map

    def setup(self):
        """Place the agent at ``start_pos`` and populate ``self.entities``."""
        self.agent.pos = self.start_pos
        # Timeout rewards are disabled (count is 0); the loop is kept so the
        # count can be raised without rewriting the setup.
        for _ in range(0):
            reward = entities.TimeoutReward(self)
            reward.radius = 30
            self.entities.append(reward)
        for _ in range(2):
            reward = entities.TeleportingReward(self)
            reward.radius = 30
            self.entities.append(reward)
|
|
|
|
|
|
|
|
|
2022-06-26 10:56:14 +02:00
|
|
|
class ColumbusCandyland_Aux10(ColumbusCandyland):
    """``ColumbusCandyland`` whose ``aux_reward_max`` defaults to 10."""

    def __init__(self, fps=30, aux_reward_max=10, **kw):
        """Forward all arguments to ``ColumbusCandyland``.

        Args:
            fps: Forwarded to the parent constructor.
            aux_reward_max: Forwarded to the parent constructor.
            **kw: Remaining keyword arguments, forwarded to the parent.
        """
        parent_kwargs = dict(kw, fps=fps, aux_reward_max=aux_reward_max)
        super().__init__(**parent_kwargs)
|
2022-06-26 10:56:14 +02:00
|
|
|
|
|
|
|
|
2022-06-20 23:10:14 +02:00
|
|
|
class ColumbusEasyObstacles(ColumbusEnv):
    """Env with 5 circular barriers, 2 teleporting rewards and 1 walking chaser."""

    def __init__(self, observable=None, hide_map=False, fps=30, env_seed=None, aux_reward_max=10, **kw):
        """Create the environment.

        Args:
            observable: Observable used by the env. Defaults to a fresh
                ``RayObservable(num_rays=16)``.
            hide_map: When true, ``draw_entities`` is set to ``False``.
            fps: Forwarded to ``ColumbusEnv``.
            env_seed: Forwarded to ``ColumbusEnv``.
            aux_reward_max: Forwarded to ``ColumbusEnv``.
            **kw: Remaining keyword arguments, forwarded to ``ColumbusEnv``.
        """
        # Build the default per instance: ColumbusEnv binds the observable to
        # the env via _set_env, so a default created once in the signature
        # would be shared by (and re-bound between) all instances.
        if observable is None:
            observable = observables.RayObservable(num_rays=16)
        super().__init__(
            observable=observable, fps=fps, env_seed=env_seed,
            aux_reward_max=aux_reward_max, **kw)
        self.draw_entities = not hide_map

    def setup(self):
        """Place the agent at ``start_pos`` and populate ``self.entities``."""
        self.agent.pos = self.start_pos
        for _ in range(5):
            barrier = entities.CircleBarrier(self)
            barrier.radius = 30 + self.random()*70
            self.entities.append(barrier)
        for _ in range(2):
            reward = entities.TeleportingReward(self)
            reward.radius = 30
            self.entities.append(reward)
        for _ in range(1):
            chaser = entities.WalkingChaser(self)
            chaser.chase_speed = 0.20
            self.entities.append(chaser)
|
|
|
|
|
|
|
|
|
|
|
|
class ColumbusEasierObstacles(ColumbusEnv):
    """Env with 5 circular barriers, 3 doubled rewards and 1 walking chaser."""

    def __init__(self, observable=None, hide_map=False, fps=30, env_seed=None, aux_reward_max=10, **kw):
        """Create the environment.

        Args:
            observable: Observable used by the env. Defaults to a fresh
                ``RayObservable(num_rays=16)``.
            hide_map: When true, ``draw_entities`` is set to ``False``.
            fps: Forwarded to ``ColumbusEnv``.
            env_seed: Forwarded to ``ColumbusEnv``.
            aux_reward_max: Forwarded to ``ColumbusEnv``.
            **kw: Remaining keyword arguments, forwarded to ``ColumbusEnv``.
        """
        # Build the default per instance: ColumbusEnv binds the observable to
        # the env via _set_env, so a default created once in the signature
        # would be shared by (and re-bound between) all instances.
        if observable is None:
            observable = observables.RayObservable(num_rays=16)
        super().__init__(
            observable=observable, fps=fps, env_seed=env_seed,
            aux_reward_max=aux_reward_max, **kw)
        self.draw_entities = not hide_map

    def setup(self):
        """Place the agent at ``start_pos`` and populate ``self.entities``."""
        self.agent.pos = self.start_pos
        for _ in range(5):
            barrier = entities.CircleBarrier(self)
            barrier.radius = 30 + self.random()*70
            self.entities.append(barrier)
        for _ in range(3):
            reward = entities.TeleportingReward(self)
            reward.radius = 30
            # Doubles whatever reward value the entity was constructed with.
            reward.reward *= 2
            self.entities.append(reward)
        for _ in range(1):
            chaser = entities.WalkingChaser(self)
            chaser.chase_speed = 0.20
            self.entities.append(chaser)
|
|
|
|
|
|
|
|
|
2022-08-07 18:03:27 +02:00
|
|
|
class ColumbusComp(ColumbusEnv):
    """Env with 5 circular barriers and 3 doubled rewards.

    By default observations combine a ray observable with a state observable.
    """

    def __init__(self, observable=None, hide_map=False, fps=30, env_seed=None, aux_reward_max=10, **kw):
        """Create the environment.

        Args:
            observable: Observable used by the env. Defaults to a fresh
                ``CompositionalObservable`` of a 6-ray ``RayObservable`` on
                the ``Enemy`` channel and a ``StateObservable`` configured as
                in the body below.
            hide_map: When true, ``draw_entities`` is set to ``False``.
            fps: Forwarded to ``ColumbusEnv``.
            env_seed: Forwarded to ``ColumbusEnv``.
            aux_reward_max: Forwarded to ``ColumbusEnv``.
            **kw: Remaining keyword arguments, forwarded to ``ColumbusEnv``.
        """
        # Build the default per instance: ColumbusEnv binds the observable to
        # the env via _set_env, so a default created once in the signature
        # would be shared by (and re-bound between) all instances.
        if observable is None:
            observable = observables.CompositionalObservable([
                observables.RayObservable(num_rays=6, chans=[entities.Enemy]),
                observables.StateObservable(
                    coordsAgent=True, speedAgent=False,
                    coordsRelativeToAgent=False, coordsRewards=True,
                    rewardsWhitelist=None, coordsEnemys=False,
                    enemysWhitelist=None, enemysNoBarriers=True,
                    rewardsTimeouts=False, include_rand=True),
            ])
        super().__init__(
            observable=observable, fps=fps, env_seed=env_seed,
            aux_reward_max=aux_reward_max, **kw)
        self.draw_entities = not hide_map

    def setup(self):
        """Place the agent at ``start_pos`` and populate ``self.entities``."""
        self.agent.pos = self.start_pos
        for _ in range(5):
            barrier = entities.CircleBarrier(self)
            barrier.radius = 30 + self.random()*70
            self.entities.append(barrier)
        for _ in range(3):
            reward = entities.TeleportingReward(self)
            reward.radius = 30
            # Doubles whatever reward value the entity was constructed with.
            reward.reward *= 2
            self.entities.append(reward)
|
|
|
|
|
|
|
|
|
2022-08-14 17:48:53 +02:00
|
|
|
class ColumbusSingle(ColumbusEnv):
    """Env with a random number of circular barriers and a single reward.

    Barrier damage and the reward value are configurable.
    """

    def __init__(self, observable=None, hide_map=False, fps=30, env_seed=None, aux_reward_max=1, enemy_damage=1, reward_reward=25, void_damage=1, **kw):
        """Create the environment.

        Args:
            observable: Observable used by the env. Defaults to a fresh
                ``CompositionalObservable`` of a 6-ray ``RayObservable`` on
                the ``Enemy`` channel and a ``StateObservable`` configured as
                in the body below.
            hide_map: When true, ``draw_entities`` is set to ``False``.
            fps: Forwarded to ``ColumbusEnv``.
            env_seed: Forwarded to ``ColumbusEnv``.
            aux_reward_max: Forwarded to ``ColumbusEnv``.
            enemy_damage: Value assigned to each barrier's ``damage``.
            reward_reward: Value assigned to the reward entity's ``reward``.
            void_damage: Forwarded to ``ColumbusEnv``.
            **kw: Remaining keyword arguments, forwarded to ``ColumbusEnv``.
        """
        # Build the default per instance: ColumbusEnv binds the observable to
        # the env via _set_env, so a default created once in the signature
        # would be shared by (and re-bound between) all instances.
        if observable is None:
            observable = observables.CompositionalObservable([
                observables.RayObservable(num_rays=6, chans=[entities.Enemy]),
                observables.StateObservable(
                    coordsAgent=False, speedAgent=False,
                    coordsRelativeToAgent=True, coordsRewards=True,
                    rewardsWhitelist=None, coordsEnemys=False,
                    enemysWhitelist=None, enemysNoBarriers=True,
                    rewardsTimeouts=False, include_rand=True),
            ])
        super().__init__(
            observable=observable, fps=fps, env_seed=env_seed,
            aux_reward_max=aux_reward_max, void_damage=void_damage, **kw)
        self.draw_entities = not hide_map
        self._enemy_damage = enemy_damage
        self._reward_reward = reward_reward

    def setup(self):
        """Place the agent at ``start_pos`` and populate ``self.entities``."""
        self.agent.pos = self.start_pos
        # The barrier count is drawn once, before any barrier is created.
        num_barriers = 4 + math.floor(self.random()*4)
        for _ in range(num_barriers):
            barrier = entities.CircleBarrier(self)
            barrier.radius = 30 + self.random()*70
            barrier.damage = self._enemy_damage
            self.entities.append(barrier)
        for _ in range(1):
            reward = entities.TeleportingReward(self)
            reward.radius = 30
            reward.reward = self._reward_reward
            self.entities.append(reward)
|
|
|
|
|
|
|
|
|
2022-06-21 22:29:10 +02:00
|
|
|
class ColumbusJustState(ColumbusEnv):
    """Env with configurable numbers of flying chasers and teleporting rewards."""

    def __init__(self, observable=None, fps=30, num_enemies=0, num_rewards=1, env_seed=None, aux_reward_max=10, **kw):
        """Create the environment.

        Args:
            observable: Observable used by the env. Defaults to a fresh
                ``StateObservable()``.
            fps: Forwarded to ``ColumbusEnv``.
            num_enemies: Number of ``FlyingChaser`` entities added in ``setup``.
            num_rewards: Number of ``TeleportingReward`` entities added in
                ``setup``.
            env_seed: Forwarded to ``ColumbusEnv``.
            aux_reward_max: Forwarded to ``ColumbusEnv``.
            **kw: Remaining keyword arguments, forwarded to ``ColumbusEnv``.
        """
        # Build the default per instance: ColumbusEnv binds the observable to
        # the env via _set_env, so a default created once in the signature
        # would be shared by (and re-bound between) all instances.
        if observable is None:
            observable = observables.StateObservable()
        super().__init__(
            observable=observable, fps=fps, env_seed=env_seed,
            aux_reward_max=aux_reward_max, **kw)
        self.num_enemies = num_enemies
        self.num_rewards = num_rewards

    def setup(self):
        """Place the agent at ``start_pos`` and populate ``self.entities``."""
        self.agent.pos = self.start_pos
        for _ in range(self.num_enemies):
            chaser = entities.FlyingChaser(self)
            chaser.chase_acc = self.random()*0.4+0.3  # *0.6+0.5
            self.entities.append(chaser)
        for _ in range(self.num_rewards):
            reward = entities.TeleportingReward(self)
            reward.radius = 30
            self.entities.append(reward)
|
|
|
|
|
|
|
|
|
2022-06-21 22:29:10 +02:00
|
|
|
class ColumbusStateWithBarriers(ColumbusEnv):
    """Env with configurable circular barriers and flying chasers plus one reward."""

    def __init__(self, observable=None, fps=30, env_seed=3.141, num_enemys=0, num_barriers=3, aux_reward_max=10, **kw):
        """Create the environment.

        Args:
            observable: Observable used by the env. Defaults to a fresh
                ``StateObservable`` configured as in the body below.
            fps: Forwarded to ``ColumbusEnv``.
            env_seed: Forwarded to ``ColumbusEnv``.
            num_enemys: Number of ``FlyingChaser`` entities added in ``setup``.
            num_barriers: Number of ``CircleBarrier`` entities added in
                ``setup``.
            aux_reward_max: Forwarded to ``ColumbusEnv``.
            **kw: Remaining keyword arguments, forwarded to ``ColumbusEnv``.
        """
        # Build the default per instance: ColumbusEnv binds the observable to
        # the env via _set_env, so a default created once in the signature
        # would be shared by (and re-bound between) all instances.
        if observable is None:
            observable = observables.StateObservable(
                coordsAgent=True, speedAgent=False,
                coordsRelativeToAgent=False, coordsRewards=True,
                rewardsWhitelist=None, coordsEnemys=True,
                enemysWhitelist=None, enemysNoBarriers=True,
                rewardsTimeouts=False, include_rand=True)
        super().__init__(
            observable=observable, fps=fps, env_seed=env_seed,
            aux_reward_max=aux_reward_max, **kw)
        self.start_pos = (0.5, 0.5)
        self.num_barriers = num_barriers
        self.num_enemys = num_enemys

    def setup(self):
        """Place the agent at ``start_pos`` and populate ``self.entities``."""
        self.agent.pos = self.start_pos
        for _ in range(self.num_barriers):
            barrier = entities.CircleBarrier(self)
            barrier.radius = self.random()*25+75
            self.entities.append(barrier)
        for _ in range(self.num_enemys):
            chaser = entities.FlyingChaser(self)
            chaser.chase_acc = 0.55  # *0.6+0.5
            self.entities.append(chaser)
        for _ in range(1):
            reward = entities.TeleportingReward(self)
            reward.radius = 30
            self.entities.append(reward)
|
|
|
|
|
|
|
|
|
2022-08-20 21:32:34 +02:00
|
|
|
class ColumbusCompassWithBarriers(ColumbusEnv):
    """Barrier/chaser env that defaults to a ``CompassObservable``."""

    def __init__(self, observable=None, fps=30, env_seed=3.141, num_enemys=0, num_barriers=3, aux_reward_max=10, **kw):
        """Create the environment.

        Args:
            observable: Observable used by the env. Defaults to a fresh
                ``CompassObservable(coordsRewards=True)``.
            fps: Forwarded to ``ColumbusEnv``.
            env_seed: Forwarded to ``ColumbusEnv``.
            num_enemys: Number of ``FlyingChaser`` entities added in ``setup``.
            num_barriers: Number of ``CircleBarrier`` entities added in
                ``setup``.
            aux_reward_max: Forwarded to ``ColumbusEnv``.
            **kw: Remaining keyword arguments, forwarded to ``ColumbusEnv``.
        """
        # Build the default per instance: ColumbusEnv binds the observable to
        # the env via _set_env, so a default created once in the signature
        # would be shared by (and re-bound between) all instances.
        if observable is None:
            observable = observables.CompassObservable(coordsRewards=True)
        super().__init__(
            observable=observable, fps=fps, env_seed=env_seed,
            aux_reward_max=aux_reward_max, **kw)
        self.start_pos = (0.5, 0.5)
        self.num_barriers = num_barriers
        self.num_enemys = num_enemys

    def setup(self):
        """Place the agent at ``start_pos`` and populate ``self.entities``."""
        self.agent.pos = self.start_pos
        for _ in range(self.num_barriers):
            barrier = entities.CircleBarrier(self)
            barrier.radius = self.random()*25+75
            self.entities.append(barrier)
        for _ in range(self.num_enemys):
            chaser = entities.FlyingChaser(self)
            chaser.chase_acc = 0.55  # *0.6+0.5
            self.entities.append(chaser)
        for _ in range(1):
            reward = entities.TeleportingReward(self)
            reward.radius = 30
            self.entities.append(reward)
|
|
|
|
|
|
|
|
|
2022-06-22 20:32:17 +02:00
|
|
|
class ColumbusTrivialRay(ColumbusStateWithBarriers):
    """``ColumbusStateWithBarriers`` without chasers, observed through rays."""

    def __init__(self, observable=None, hide_map=False, fps=30, **kw):
        """Create the environment.

        Args:
            observable: Observable used by the env. Defaults to a fresh
                ``RayObservable(num_rays=8, ray_len=512)``.
            hide_map: When true, ``draw_entities`` is set to ``False``.
            fps: Forwarded to ``ColumbusStateWithBarriers``.
            **kw: Remaining keyword arguments, forwarded to the parent.
        """
        # Build the default per instance: ColumbusEnv binds the observable to
        # the env via _set_env, so a default created once in the signature
        # would be shared by (and re-bound between) all instances.
        if observable is None:
            observable = observables.RayObservable(num_rays=8, ray_len=512)
        # The parent's chaser-count parameter is spelled ``num_enemys``. The
        # previous ``num_chasers=0`` fell through **kw to ColumbusEnv, which
        # accepts no such keyword, so construction always raised TypeError.
        super().__init__(
            observable=observable, fps=fps, num_enemys=0, **kw)
        self.draw_entities = not hide_map
|
|
|
|
|
2022-06-22 16:05:30 +02:00
|
|
|
|
2022-06-30 14:42:56 +02:00
|
|
|
class ColumbusFootball(ColumbusEnv):
    """Football env: 8 circular barriers, a ball, a goal and optional opponents.

    ``env_seed`` is always passed to ``ColumbusEnv`` as ``None``.
    """

    def __init__(self, observable=None, fps=30, walkingOpponent=0, flyingOpponent=0, **kw):
        """Create the environment.

        Args:
            observable: Observable used by the env. Defaults to a fresh
                16-ray ``RayObservable`` with ``Goal``, ``Ball`` and
                ``Barrier`` channels.
            fps: Forwarded to ``ColumbusEnv``.
            walkingOpponent: Number of ``WalkingFootballPlayer`` entities
                added in ``setup``.
            flyingOpponent: Number of ``FlyingFootballPlayer`` entities added
                in ``setup``.
            **kw: Remaining keyword arguments, forwarded to ``ColumbusEnv``.
        """
        # Build the default per instance: ColumbusEnv binds the observable to
        # the env via _set_env, so a default created once in the signature
        # would be shared by (and re-bound between) all instances.
        if observable is None:
            observable = observables.RayObservable(
                num_rays=16,
                chans=[entities.Goal, entities.Ball, entities.Barrier])
        super().__init__(
            observable=observable, fps=fps, env_seed=None, **kw)
        self.start_pos = [0.5, 0.5]
        self.score = 0
        self.walkingOpponents = walkingOpponent
        self.flyingOpponents = flyingOpponent

    def setup(self):
        """Place the agent at ``start_pos`` and populate ``self.entities``."""
        self.agent.pos = self.start_pos
        for _ in range(8):
            barrier = entities.CircleBarrier(self)
            barrier.radius = self.random()*40+50
            self.entities.append(barrier)
        # The ball is created first because the opponents are constructed
        # with a reference to it.
        ball = entities.Ball(self)
        self.entities.append(ball)
        self.entities.append(entities.TeleportingGoal(self))
        for _ in range(self.walkingOpponents):
            self.entities.append(entities.WalkingFootballPlayer(self, ball))
        for _ in range(self.flyingOpponents):
            self.entities.append(entities.FlyingFootballPlayer(self, ball))
|
|
|
|
|
|
|
|
|
2022-08-27 12:03:32 +02:00
|
|
|
class ColumbusBlub(ColumbusEnv):
    """Env containing a single 200x75 rectangular barrier.

    Built with 'ACC' control and fixed physics parameters (see ``__init__``).
    """

    @staticmethod
    def _default_observable():
        """Return a fresh default observable (state plus 6 enemy rays)."""
        # Lives outside __init__ because that method's ``entities`` parameter
        # shadows the ``entities`` module needed here.
        return observables.CompositionalObservable([
            observables.StateObservable(),
            observables.RayObservable(num_rays=6, chans=[entities.Enemy]),
        ])

    def __init__(self, observable=None, env_seed=None, entities=None, fps=30, **kw):
        """Create the environment.

        Args:
            observable: Observable used by the env. Defaults to a fresh
                observable from ``_default_observable``.
            env_seed: Forwarded to ``ColumbusEnv``.
            entities: Accepted but not used by this constructor.
            fps: Forwarded to ``ColumbusEnv``.
            **kw: Accepted but NOT forwarded to ``ColumbusEnv``; extra keyword
                arguments are silently ignored.
                NOTE(review): confirm whether dropping them is intended.
        """
        # Build the default per instance: ColumbusEnv binds the observable to
        # the env via _set_env, so a default created once in the signature
        # would be shared by (and re-bound between) all instances.
        if observable is None:
            observable = ColumbusBlub._default_observable()
        super().__init__(
            observable=observable, fps=fps, env_seed=env_seed,
            default_collision_elasticity=0.8, speed_fac=0.01, acc_fac=0.1,
            agent_drag=0.06, controll_type='ACC', aux_penalty_max=1)

    def setup(self):
        """Place the agent at ``start_pos`` and add the rectangular barrier."""
        self.agent.pos = self.start_pos
        for _ in range(1):
            barrier = entities.RectBarrier(self)
            # radius is assigned in addition to width/height.
            # NOTE(review): unclear whether RectBarrier reads it - confirm.
            barrier.radius = 100
            barrier.width, barrier.height = 200, 75
            self.entities.append(barrier)
|
2022-08-27 12:03:32 +02:00
|
|
|
|
2022-12-06 12:16:42 +01:00
|
|
|
|
2022-06-22 20:32:17 +02:00
|
|
|
###
|
2022-12-06 12:16:42 +01:00
|
|
|
# Registering Envs for Gym
|
2022-12-08 19:53:32 +01:00
|
|
|
# Legacy id; episodes are capped at 2 min at the default (30) fps.
register(
    id='ColumbusConfigDefined-v0',
    entry_point=ColumbusConfigDefined,
    max_episode_steps=30*60*2,
)

# Current id for the same entry point; no max_episode_steps is passed here.
register(id='Columbus-v1', entry_point=ColumbusConfigDefined)
|
|
|
|
|
|
|
|
###
|
|
|
|
|
2022-09-13 22:14:17 +02:00
|
|
|
# register(
|
|
|
|
# id='ColumbusBlub-v0',
|
|
|
|
# entry_point=ColumbusBlub,
|
|
|
|
# max_episode_steps=30*60*2,
|
|
|
|
# )
|
2022-08-27 12:03:32 +02:00
|
|
|
|
|
|
|
|
2022-12-08 19:53:32 +01:00
|
|
|
# register(
|
|
|
|
# id='ColumbusTestCnn-v0',
|
|
|
|
# entry_point=ColumbusTest3_1,
|
|
|
|
# max_episode_steps=30*60*2,
|
|
|
|
# )
|
2022-06-20 23:10:14 +02:00
|
|
|
|
2022-12-08 19:53:32 +01:00
|
|
|
# register(
|
|
|
|
# id='ColumbusTestRay-v0',
|
|
|
|
# entry_point=ColumbusTestRay,
|
|
|
|
# max_episode_steps=30*60*2,
|
|
|
|
# )
|
2022-06-20 23:10:14 +02:00
|
|
|
|
2022-09-13 22:14:17 +02:00
|
|
|
# register(
|
|
|
|
# id='ColumbusRayDrone-v0',
|
|
|
|
# entry_point=ColumbusRayDrone,
|
|
|
|
# max_episode_steps=30*60*2,
|
|
|
|
# )
|
|
|
|
|
|
|
|
# register(
|
|
|
|
# id='ColumbusCandyland-v0',
|
|
|
|
# entry_point=ColumbusCandyland,
|
|
|
|
# max_episode_steps=30*60*2,
|
|
|
|
# )
|
|
|
|
|
|
|
|
# register(
|
|
|
|
# id='ColumbusCandyland_Aux10-v0',
|
|
|
|
# entry_point=ColumbusCandyland_Aux10,
|
|
|
|
# max_episode_steps=30*60*2,
|
|
|
|
# )
|
|
|
|
|
|
|
|
# register(
|
|
|
|
# id='ColumbusEasyObstacles-v0',
|
|
|
|
# entry_point=ColumbusEasyObstacles,
|
|
|
|
# max_episode_steps=30*60*2,
|
|
|
|
# )
|
|
|
|
|
|
|
|
# register(
|
|
|
|
# id='ColumbusEasierObstacles-v0',
|
|
|
|
# entry_point=ColumbusEasyObstacles,
|
|
|
|
# max_episode_steps=30*60*2,
|
|
|
|
# )
|
|
|
|
|
|
|
|
# register(
|
|
|
|
# id='ColumbusJustState-v0',
|
|
|
|
# entry_point=ColumbusJustState,
|
|
|
|
# max_episode_steps=30*60*2,
|
|
|
|
# )
|
2022-06-29 12:41:52 +02:00
|
|
|
|
2022-12-08 19:53:32 +01:00
|
|
|
# register(
|
|
|
|
# id='ColumbusStateWithBarriers-v0',
|
|
|
|
# entry_point=ColumbusStateWithBarriers,
|
|
|
|
# max_episode_steps=30*60*2,
|
|
|
|
# )
|
2022-06-22 20:32:17 +02:00
|
|
|
|
2022-09-13 22:14:17 +02:00
|
|
|
# register(
|
|
|
|
# id='ColumbusCompassWithBarriers-v0',
|
|
|
|
# entry_point=ColumbusCompassWithBarriers,
|
|
|
|
# max_episode_steps=30*60*2,
|
|
|
|
# )
|
|
|
|
|
|
|
|
# register(
|
|
|
|
# id='ColumbusTrivialRay-v0',
|
|
|
|
# entry_point=ColumbusTrivialRay,
|
|
|
|
# max_episode_steps=30*60*2,
|
|
|
|
# )
|
|
|
|
|
|
|
|
# register(
|
|
|
|
# id='ColumbusFootball-v0',
|
|
|
|
# entry_point=ColumbusFootball,
|
|
|
|
# max_episode_steps=30*60*2,
|
|
|
|
# )
|
|
|
|
|
|
|
|
# register(
|
|
|
|
# id='ColumbusComb-v0',
|
|
|
|
# entry_point=ColumbusComp,
|
|
|
|
# max_episode_steps=30*60*2,
|
|
|
|
# )
|
|
|
|
|
|
|
|
# register(
|
|
|
|
# id='ColumbusSingle-v0',
|
|
|
|
# entry_point=ColumbusSingle,
|
|
|
|
# max_episode_steps=30*60*2,
|
|
|
|
# )
|
2022-06-30 14:42:56 +02:00
|
|
|
|
2022-08-07 18:03:27 +02:00
|
|
|
# Demo environments, each registered with the same step cap of 30*60*2.
register(
    id='ColumbusDemoEnvFootball-v0',
    entry_point=ColumbusDemoEnvFootball,
    max_episode_steps=30*60*2,
)

register(
    id='ColumbusDemoEnv3_1-v0',
    entry_point=ColumbusDemoEnv3_1,
    max_episode_steps=30*60*2,
)

register(
    id='ColumbusDemoEnv2_7-v0',
    entry_point=ColumbusDemoEnv2_7,
    max_episode_steps=30*60*2,
)
|