More precise aux_penalties

This commit is contained in:
Dominik Moritz Roth 2022-10-25 14:43:29 +02:00
parent 71809b7374
commit 910acc2a15

View File

@ -78,7 +78,7 @@ class ColumbusEnv(gym.Env):
self.aux_penalty_max = aux_penalty_max # 0 = off
self.aux_reward_discretize = aux_reward_discretize
# 0 = dont discretize; how many steps (along diagonal)
self.penalty_from_edges = False # not ready yet...
self.penalty_from_edges = True # Don't change, only here to allow legacy behavior
self.draw_observable = True
self.draw_joystick = True
self.draw_entities = True
@ -182,9 +182,32 @@ class ColumbusEnv(gym.Env):
elif isinstance(entity, entities.Enemy):
if entity.radiateDamage:
if self.penalty_from_edges:
penalty = self.aux_penalty_max / \
(1 + self.sq_dist(entity.pos,
self.agent.pos) - entity.radius - self.agent.radius)
if self.agent.shape != 'circle':
raise Exception(
'Radiating damage from edge for non-circle Agents not supported')
if entity.shape == 'circle':
penalty = self.aux_penalty_max / \
(1 + self.sq_dist(entity.pos,
self.agent.pos) - (entity.radius/max(self.height, self.width))**2 - (self.agent.radius/max(self.height, self.width))**2)
elif entity.shape == 'rect':
ax, ay = self.agent.pos
ex, ey, ex2, ey2 = entity.pos[0], entity.pos[1], entity.pos[0] + \
entity.width / \
self.width, entity.pos[1] + \
entity.height/self.height
lx, ly = ax, ay # 'Lotpunkt'
if ax < ex:
lx = ex
elif ax > ex2:
lx = ex2
if ay < ey:
ly = ey
elif ay > ey2:
ly = ey2
penalty = self.aux_penalty_max / \
(1 + self.sq_dist((lx, ly),
(ax, ay)) - (self.agent.radius/max(self.height, self.width))**2)
else:
penalty = self.aux_penalty_max / \
(1 + self.sq_dist(entity.pos, self.agent.pos))
@ -845,20 +868,15 @@ class ColumbusConfigDefined(ColumbusEnv):
class ColumbusBlub(ColumbusEnv):
def __init__(self, observable=observables.CompositionalObservable([observables.StateObservable(), observables.RayObservable(num_rays=6, chans=[entities.Enemy])]), env_seed=None, entities=[], fps=30, **kw):
super().__init__(
observable=observable, fps=fps, env_seed=env_seed, default_collision_elasticity=0.8, speed_fac=0.01, acc_fac=0.1, agent_drag=0.06, controll_type='ACC')
observable=observable, fps=fps, env_seed=env_seed, default_collision_elasticity=0.8, speed_fac=0.01, acc_fac=0.1, agent_drag=0.06, controll_type='ACC', aux_penalty_max=1)
def setup(self):
self.agent.pos = self.start_pos
for i in range(10):
enemy = entities.CircleBarrier(self)
enemy.radius = self.random()*25+75
self.entities.append(enemy)
for i in range(1):
reward = entities.TeleportingReward(self)
reward.radius = 20
reward.reward = 25
self.entities.append(reward)
enemy = entities.RectBarrier(self)
enemy.radius = 100
enemy.width, enemy.height = 200, 75
self.entities.append(enemy)
###
# register(