Better path visualizations; more agent & path configs
This commit is contained in:
parent
78ac536bb9
commit
5afa8b22b2
@ -28,6 +28,9 @@ class Entity(object):
|
|||||||
self.draw_path = False
|
self.draw_path = False
|
||||||
self.draw_path_col = [int(c/5) for c in self.col]
|
self.draw_path_col = [int(c/5) for c in self.col]
|
||||||
self.draw_path_width = 2
|
self.draw_path_width = 2
|
||||||
|
self.draw_path_harm = False
|
||||||
|
self.draw_path_harm_col = [c for c in self.draw_path_col]
|
||||||
|
self.draw_path_harm_col[0] += int(255/3)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
pass
|
pass
|
||||||
@ -63,8 +66,14 @@ class Entity(object):
|
|||||||
|
|
||||||
def _draw_path(self):
|
def _draw_path(self):
|
||||||
if self.draw_path and self.last_pos:
|
if self.draw_path and self.last_pos:
|
||||||
pygame.draw.line(self.env.path_overlay, self.draw_path_col,
|
col = self.draw_path_col
|
||||||
|
if self.draw_path_harm:
|
||||||
|
if self.env.gotHarm:
|
||||||
|
col = self.draw_path_harm_col
|
||||||
|
pygame.draw.line(self.env.path_overlay, col,
|
||||||
(self.last_pos[0]*self.env.width, self.last_pos[1]*self.env.height), (self.pos[0]*self.env.width, self.pos[1]*self.env.height), self.draw_path_width)
|
(self.last_pos[0]*self.env.width, self.last_pos[1]*self.env.height), (self.pos[0]*self.env.width, self.pos[1]*self.env.height), self.draw_path_width)
|
||||||
|
pygame.draw.circle(self.env.path_overlay, col,
|
||||||
|
(self.pos[0]*self.env.width, self.pos[1]*self.env.height), max(0, self.draw_path_width/2-1))
|
||||||
self.last_pos = self.pos[0], self.pos[1]
|
self.last_pos = self.pos[0], self.pos[1]
|
||||||
|
|
||||||
def on_collision(self, other, depth):
|
def on_collision(self, other, depth):
|
||||||
|
@ -6,47 +6,16 @@ import pygame
|
|||||||
import random as random_dont_use
|
import random as random_dont_use
|
||||||
from os import urandom
|
from os import urandom
|
||||||
import math
|
import math
|
||||||
from columbus import entities, observables
|
|
||||||
import torch as th
|
import torch as th
|
||||||
|
|
||||||
|
from columbus import entities, observables
|
||||||
def parseObs(obsConf):
|
from columbus.utils import soft_int, parseObs
|
||||||
# Parsing Observable Definitions
|
|
||||||
if type(obsConf) == list:
|
|
||||||
obs = []
|
|
||||||
for i, c in enumerate(obsConf):
|
|
||||||
obs.append(parseObs(c))
|
|
||||||
if len(obs) == 1:
|
|
||||||
return obs[0]
|
|
||||||
else:
|
|
||||||
return observables.CompositionalObservable(obs)
|
|
||||||
|
|
||||||
if obsConf['type'] == 'State':
|
|
||||||
conf = {k: v for k, v in obsConf.items() if k not in ['type']}
|
|
||||||
return observables.StateObservable(**conf)
|
|
||||||
elif obsConf['type'] == 'Compass':
|
|
||||||
conf = {k: v for k, v in obsConf.items() if k not in ['type']}
|
|
||||||
return observables.CompassObservable(**conf)
|
|
||||||
elif obsConf['type'] == 'RayCast':
|
|
||||||
chans = []
|
|
||||||
for chan in obsConf.get('chans', []):
|
|
||||||
chans.append(getattr(entities, chan))
|
|
||||||
conf = {k: v for k, v in obsConf.items() if k not in ['type', 'chans']}
|
|
||||||
return observables.RayObservable(chans=chans, **conf)
|
|
||||||
elif obsConf['type'] == 'CNN':
|
|
||||||
conf = {k: v for k, v in obsConf.items() if k not in ['type']}
|
|
||||||
return observables.CnnObservable(**conf)
|
|
||||||
elif obsConf['type'] == 'Dummy':
|
|
||||||
conf = {k: v for k, v in obsConf.items() if k not in ['type']}
|
|
||||||
return observables.Observable(**conf)
|
|
||||||
else:
|
|
||||||
raise Exception('Unknown Observable selected')
|
|
||||||
|
|
||||||
|
|
||||||
class ColumbusEnv(gym.Env):
|
class ColumbusEnv(gym.Env):
|
||||||
metadata = {'render.modes': ['human']}
|
metadata = {'render.modes': ['human']}
|
||||||
|
|
||||||
def __init__(self, observable=observables.Observable(), fps=60, env_seed=3.1, master_seed=None, start_pos=(0.5, 0.5), start_score=0, speed_fac=0.01, acc_fac=0.04, die_on_zero=False, return_on_score=-1, reward_mult=1, agent_drag=0, controll_type='SPEED', aux_reward_max=1, aux_penalty_max=0, aux_reward_discretize=0, void_is_type_barrier=True, void_damage=1, torus_topology=False, default_collision_elasticity=1, terminate_on_reward=False, agent_draw_path=False, clear_path_on_reset=True, max_steps=-1, value_color_mapper='tanh', width=720, height=720):
|
def __init__(self, observable=observables.Observable(), fps=60, env_seed=3.1, master_seed=None, start_pos=(0.5, 0.5), start_score=0, speed_fac=0.01, acc_fac=0.04, die_on_zero=False, return_on_score=-1, reward_mult=1, agent_drag=0, controll_type='SPEED', aux_reward_max=1, aux_penalty_max=0, aux_reward_discretize=0, void_is_type_barrier=True, void_damage=1, torus_topology=False, default_collision_elasticity=1, terminate_on_reward=False, agent_draw_path=False, clear_path_on_reset=True, max_steps=-1, value_color_mapper='tanh', width=720, height=720, agent_attrs={}):
|
||||||
super(ColumbusEnv, self).__init__()
|
super(ColumbusEnv, self).__init__()
|
||||||
self.action_space = spaces.Box(
|
self.action_space = spaces.Box(
|
||||||
low=-1, high=1, shape=(2,), dtype=np.float32)
|
low=-1, high=1, shape=(2,), dtype=np.float32)
|
||||||
@ -92,6 +61,9 @@ class ColumbusEnv(gym.Env):
|
|||||||
self.terminate_on_reward = terminate_on_reward
|
self.terminate_on_reward = terminate_on_reward
|
||||||
self.agent_draw_path = agent_draw_path
|
self.agent_draw_path = agent_draw_path
|
||||||
self.clear_path_on_reset = clear_path_on_reset
|
self.clear_path_on_reset = clear_path_on_reset
|
||||||
|
self.path_decay = 0.1
|
||||||
|
|
||||||
|
self.agent_attrs = agent_attrs
|
||||||
|
|
||||||
if value_color_mapper == 'atan':
|
if value_color_mapper == 'atan':
|
||||||
def value_color_mapper(x): return th.atan(x*2)/0.786/2
|
def value_color_mapper(x): return th.atan(x*2)/0.786/2
|
||||||
@ -241,6 +213,7 @@ class ColumbusEnv(gym.Env):
|
|||||||
self._step_entities()
|
self._step_entities()
|
||||||
observation = self.observable.get_observation()
|
observation = self.observable.get_observation()
|
||||||
gotRew = self.new_reward > 0 or self.new_abs_reward > 0
|
gotRew = self.new_reward > 0 or self.new_abs_reward > 0
|
||||||
|
self.gotHarm = self.new_reward < 0 or self.new_abs_reward < 0
|
||||||
reward, self.new_reward, self.new_abs_reward = self.new_reward / \
|
reward, self.new_reward, self.new_abs_reward = self.new_reward / \
|
||||||
self.fps + self.new_abs_reward, 0, 0
|
self.fps + self.new_abs_reward, 0, 0
|
||||||
if not self.torus_topology:
|
if not self.torus_topology:
|
||||||
@ -296,6 +269,12 @@ class ColumbusEnv(gym.Env):
|
|||||||
self.agent.pos = self.start_pos
|
self.agent.pos = self.start_pos
|
||||||
# Expand this function
|
# Expand this function
|
||||||
|
|
||||||
|
def _spawnAgent(self):
    """Create the agent entity for this environment and configure it.

    A new ``entities.Agent`` bound to this environment is stored on
    ``self.agent``. Its ``draw_path`` flag is taken from
    ``self.agent_draw_path``, and afterwards every key/value pair in
    ``self.agent_attrs`` is written onto the agent as an attribute, so an
    entry there can override ``draw_path`` as well.
    """
    self.agent = agent = entities.Agent(self)
    agent.draw_path = self.agent_draw_path
    # Apply user-supplied overrides last so they take precedence.
    for attr_name, attr_value in self.agent_attrs.items():
        setattr(agent, attr_name, attr_value)
|
||||||
|
|
||||||
def reset(self, force_reset_path=False):
|
def reset(self, force_reset_path=False):
|
||||||
pygame.init()
|
pygame.init()
|
||||||
self._init = True
|
self._init = True
|
||||||
@ -308,11 +287,11 @@ class ColumbusEnv(gym.Env):
|
|||||||
# will get rescaled acording to fps (=reward per second)
|
# will get rescaled acording to fps (=reward per second)
|
||||||
self.new_reward = 0
|
self.new_reward = 0
|
||||||
self.new_abs_reward = 0 # will not get rescaled. should be used for one-time rewards
|
self.new_abs_reward = 0 # will not get rescaled. should be used for one-time rewards
|
||||||
|
self.gotHarm = False
|
||||||
self.score = self.start_score
|
self.score = self.start_score
|
||||||
self.entities = []
|
self.entities = []
|
||||||
self.timers = []
|
self.timers = []
|
||||||
self.agent = entities.Agent(self)
|
self._spawnAgent()
|
||||||
self.agent.draw_path = self.agent_draw_path
|
|
||||||
self.setup()
|
self.setup()
|
||||||
self.entities.append(self.agent) # add it last, will be drawn on top
|
self.entities.append(self.agent) # add it last, will be drawn on top
|
||||||
self.observable.reset()
|
self.observable.reset()
|
||||||
@ -428,6 +407,14 @@ class ColumbusEnv(gym.Env):
|
|||||||
self.screen.blit(rotated_surf, rotated_surf.get_rect(
|
self.screen.blit(rotated_surf, rotated_surf.get_rect(
|
||||||
center=rect.center))
|
center=rect.center))
|
||||||
|
|
||||||
|
def _draw_paths(self):
    """Fade the path overlay (if decay is enabled) and draw it on the surface.

    When ``self.path_decay`` is non-zero, a black surface the size of the
    window is blitted onto ``self.path_overlay`` with an alpha derived from
    ``255 * path_decay / fps`` (passed through ``soft_int``). The overlay is
    then blitted onto ``self.surface`` at the origin.
    """
    # NOTE(review): the source this was recovered from had lost its
    # indentation; the final blit is assumed to be unconditional because it
    # takes the place of a previously unconditional overlay blit - confirm.
    decay = self.path_decay
    if decay != 0.0:
        fade = pygame.Surface((self.width, self.height))
        alpha = soft_int(255*decay/self.fps)
        fade.set_alpha(alpha)
        fade.fill((0, 0, 0))
        self.path_overlay.blit(fade, (0, 0))
    self.surface.blit(self.path_overlay, (0, 0))
|
||||||
|
|
||||||
def _handle_user_input(self):
|
def _handle_user_input(self):
|
||||||
for event in pygame.event.get():
|
for event in pygame.event.get():
|
||||||
pass
|
pass
|
||||||
@ -472,7 +459,7 @@ class ColumbusEnv(gym.Env):
|
|||||||
if value_func != None:
|
if value_func != None:
|
||||||
self._draw_values(value_func, values_static,
|
self._draw_values(value_func, values_static,
|
||||||
color_mapper=self.value_color_mapper)
|
color_mapper=self.value_color_mapper)
|
||||||
self.surface.blit(self.path_overlay, (0, 0))
|
self._draw_paths()
|
||||||
if self.draw_entities:
|
if self.draw_entities:
|
||||||
self._draw_entities()
|
self._draw_entities()
|
||||||
else:
|
else:
|
||||||
|
@ -132,6 +132,8 @@ class RayObservable(Observable):
|
|||||||
'Can only raycast circular and rectangular entities!')
|
'Can only raycast circular and rectangular entities!')
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# Filter out entities that we are sure are out of range
|
||||||
|
# (so we have to do less work for the ray collisions)
|
||||||
def _get_possible_entities(self):
|
def _get_possible_entities(self):
|
||||||
entities_l = []
|
entities_l = []
|
||||||
if entities.Void in self.chans or self.env.void_barrier:
|
if entities.Void in self.chans or self.env.void_barrier:
|
||||||
@ -153,6 +155,8 @@ class RayObservable(Observable):
|
|||||||
entities_l.append(entity) # cannot use yield here!
|
entities_l.append(entity) # cannot use yield here!
|
||||||
return entities_l
|
return entities_l
|
||||||
|
|
||||||
|
# Ugly, inefficient ray casting
|
||||||
|
# Oh well, it works...
|
||||||
def get_observation(self):
|
def get_observation(self):
|
||||||
entities = self._get_possible_entities()
|
entities = self._get_possible_entities()
|
||||||
self.rays = np.zeros((self.num_rays+self.include_rand, self.num_chans))
|
self.rays = np.zeros((self.num_rays+self.include_rand, self.num_chans))
|
||||||
|
42
columbus/utils.py
Normal file
42
columbus/utils.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
from columbus import entities, observables
|
||||||
|
|
||||||
|
import random as random_dont_use
|
||||||
|
|
||||||
|
|
||||||
|
def parseObs(obsConf):
    """Build an observable object from an observable configuration.

    Args:
        obsConf: Either a mapping describing one observable, or a list/tuple
            of such configurations (which may themselves be nested
            sequences). A mapping must contain a ``'type'`` key with one of
            ``'State'``, ``'Compass'``, ``'RayCast'``, ``'CNN'`` or
            ``'Dummy'``; all remaining items are forwarded as keyword
            arguments to the selected observable class. For ``'RayCast'``
            the optional ``'chans'`` entry is a sequence of attribute names
            that are looked up on the ``entities`` module and passed on as
            ``chans``.

    Returns:
        The constructed observable. A sequence with exactly one element
        yields that element's observable directly; any other sequence
        (including an empty one) is wrapped in
        ``observables.CompositionalObservable``.

    Raises:
        Exception: If ``obsConf['type']`` is not one of the known types.
        KeyError: If a mapping has no ``'type'`` key.
    """
    # isinstance (instead of an exact type comparison) also accepts list
    # subclasses and tuples, which config loaders may produce.
    if isinstance(obsConf, (list, tuple)):
        obs = [parseObs(sub_conf) for sub_conf in obsConf]
        if len(obs) == 1:
            return obs[0]
        return observables.CompositionalObservable(obs)

    obs_type = obsConf['type']

    if obs_type == 'RayCast':
        # Channel names are resolved to the objects of the same name that
        # live in the entities module.
        chans = [getattr(entities, chan) for chan in obsConf.get('chans', [])]
        conf = {k: v for k, v in obsConf.items()
                if k not in ('type', 'chans')}
        return observables.RayObservable(chans=chans, **conf)

    # Types whose configuration is passed through unchanged (minus 'type').
    # Class names are kept as strings so that only the selected class is
    # looked up on the observables module.
    simple_types = {
        'State': 'StateObservable',
        'Compass': 'CompassObservable',
        'CNN': 'CnnObservable',
        'Dummy': 'Observable',
    }
    class_name = simple_types.get(obs_type)
    if class_name is None:
        raise Exception('Unknown Observable selected')

    conf = {k: v for k, v in obsConf.items() if k != 'type'}
    return getattr(observables, class_name)(**conf)
|
||||||
|
|
||||||
|
|
||||||
|
def soft_int(num):
    """Stochastically round ``num`` to one of its two neighbouring integers.

    The result is ``floor(num)`` plus one with probability equal to the
    fractional part of ``num``, so the expected value of the result equals
    ``num``. Integer-valued inputs are returned unchanged (as ``int``).

    Args:
        num: A real number (anything accepted by ``int()`` that supports
            comparison and subtraction with ints).

    Returns:
        int: ``floor(num)`` or ``floor(num) + 1``.
    """
    whole = int(num)
    if whole > num:
        # int() truncates toward zero; for negative non-integers that is the
        # ceiling, which would make the fractional part negative and the
        # random draw below could never succeed. Step down to the floor.
        whole -= 1
    frac = num - whole
    return whole + int(random_dont_use.random() < frac)
|
Loading…
Reference in New Issue
Block a user