More configurability
This commit is contained in:
parent
5afa8b22b2
commit
3bb3ffa3a0
@ -19,6 +19,7 @@ class Entity(object):
|
|||||||
self.col = (255, 255, 255)
|
self.col = (255, 255, 255)
|
||||||
self.solid = False
|
self.solid = False
|
||||||
self.movable = False # False = Non movable, True = Movable, x>1: lighter movable
|
self.movable = False # False = Non movable, True = Movable, x>1: lighter movable
|
||||||
|
self.void_collidable = True
|
||||||
self.elasticity = 1
|
self.elasticity = 1
|
||||||
self.collision_changes_speed = self.env.controll_type == 'ACC'
|
self.collision_changes_speed = self.env.controll_type == 'ACC'
|
||||||
self.collision_elasticity = self.env.default_collision_elasticity
|
self.collision_elasticity = self.env.default_collision_elasticity
|
||||||
@ -31,6 +32,8 @@ class Entity(object):
|
|||||||
self.draw_path_harm = False
|
self.draw_path_harm = False
|
||||||
self.draw_path_harm_col = [c for c in self.draw_path_col]
|
self.draw_path_harm_col = [c for c in self.draw_path_col]
|
||||||
self.draw_path_harm_col[0] += int(255/3)
|
self.draw_path_harm_col[0] += int(255/3)
|
||||||
|
self.min_speed = 0
|
||||||
|
self.max_speed = math.inf
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
pass
|
pass
|
||||||
@ -40,6 +43,11 @@ class Entity(object):
|
|||||||
vx, vy = self.speed
|
vx, vy = self.speed
|
||||||
ax, ay = self.acc
|
ax, ay = self.acc
|
||||||
vx, vy = vx+ax*self.env.acc_fac, vy+ay*self.env.acc_fac
|
vx, vy = vx+ax*self.env.acc_fac, vy+ay*self.env.acc_fac
|
||||||
|
speeds = math.sqrt(vx**2 + vy**2)
|
||||||
|
if speeds < self.min_speed:
|
||||||
|
vx, vy = vx/speeds*self.min_speed, vy/speeds*self.min_speed
|
||||||
|
if speeds > self.max_speed:
|
||||||
|
vx, vy = vx/speeds*self.max_speed, vy/speeds*self.max_speed
|
||||||
x, y = x+vx*self.env.speed_fac, y+vy*self.env.speed_fac
|
x, y = x+vx*self.env.speed_fac, y+vy*self.env.speed_fac
|
||||||
if not self.env.torus_topology:
|
if not self.env.torus_topology:
|
||||||
if x > 1 or x < 0:
|
if x > 1 or x < 0:
|
||||||
@ -73,7 +81,7 @@ class Entity(object):
|
|||||||
pygame.draw.line(self.env.path_overlay, col,
|
pygame.draw.line(self.env.path_overlay, col,
|
||||||
(self.last_pos[0]*self.env.width, self.last_pos[1]*self.env.height), (self.pos[0]*self.env.width, self.pos[1]*self.env.height), self.draw_path_width)
|
(self.last_pos[0]*self.env.width, self.last_pos[1]*self.env.height), (self.pos[0]*self.env.width, self.pos[1]*self.env.height), self.draw_path_width)
|
||||||
pygame.draw.circle(self.env.path_overlay, col,
|
pygame.draw.circle(self.env.path_overlay, col,
|
||||||
(self.pos[0]*self.env.width, self.pos[1]*self.env.height), max(0, self.draw_path_width/2-1))
|
(self.pos[0]*self.env.width, self.pos[1]*self.env.height), max(0, self.draw_path_width/2-3))
|
||||||
self.last_pos = self.pos[0], self.pos[1]
|
self.last_pos = self.pos[0], self.pos[1]
|
||||||
|
|
||||||
def on_collision(self, other, depth):
|
def on_collision(self, other, depth):
|
||||||
@ -92,10 +100,17 @@ class Entity(object):
|
|||||||
return
|
return
|
||||||
force_dir = force_dir[0]/force_dir_len, force_dir[1]/force_dir_len
|
force_dir = force_dir[0]/force_dir_len, force_dir[1]/force_dir_len
|
||||||
if not self.env.torus_topology:
|
if not self.env.torus_topology:
|
||||||
if self.env.agent.pos[0] > 0.99 or self.env.agent.pos[0] < 0.01:
|
if self == self.env.agent:
|
||||||
force_dir = force_dir[0], force_dir[1] * 2
|
agent = self
|
||||||
if self.env.agent.pos[1] > 0.99 or self.env.agent.pos[1] < 0.01:
|
elif other == self.env.agent:
|
||||||
force_dir = force_dir[0] * 2, force_dir[1]
|
agent = other
|
||||||
|
else:
|
||||||
|
agent = None
|
||||||
|
if agent:
|
||||||
|
if agent.pos[0] > 0.99 or agent.pos[0] < 0.01:
|
||||||
|
force_dir = force_dir[0], force_dir[1] * 2
|
||||||
|
if agent.pos[1] > 0.99 or agent.pos[1] < 0.01:
|
||||||
|
force_dir = force_dir[0] * 2, force_dir[1]
|
||||||
depth *= 1.0*self.movable/(self.movable + other.movable)/2
|
depth *= 1.0*self.movable/(self.movable + other.movable)/2
|
||||||
depth /= other.elasticity
|
depth /= other.elasticity
|
||||||
force_vec = force_dir[0]*depth/self.env.width, \
|
force_vec = force_dir[0]*depth/self.env.width, \
|
||||||
@ -139,6 +154,24 @@ class Entity(object):
|
|||||||
def kill(self):
|
def kill(self):
|
||||||
self.env.kill_entity(self)
|
self.env.kill_entity(self)
|
||||||
|
|
||||||
|
def getQuasiRadius(self):
    """Return a radius-like size estimate for this entity.

    The shapeless base class cannot answer this; the circular and
    rectangular subclasses override it.

    Raises:
        NotImplementedError: always (a subclass of ``Exception``, so
            existing ``except Exception`` handlers keep working).
    """
    raise NotImplementedError(
        type(self).__name__ + ' does not implement getQuasiRadius()')
|
||||||
|
|
||||||
|
def getTop(self):
    """Return the coordinate of the entity's top edge.

    The shapeless base class cannot answer this; the circular and
    rectangular subclasses override it.

    Raises:
        NotImplementedError: always (a subclass of ``Exception``, so
            existing ``except Exception`` handlers keep working).
    """
    raise NotImplementedError(
        type(self).__name__ + ' does not implement getTop()')
|
||||||
|
|
||||||
|
def getBottom(self):
    """Return the coordinate of the entity's bottom edge.

    The shapeless base class cannot answer this; the circular and
    rectangular subclasses override it.

    Raises:
        NotImplementedError: always (a subclass of ``Exception``, so
            existing ``except Exception`` handlers keep working).
    """
    raise NotImplementedError(
        type(self).__name__ + ' does not implement getBottom()')
|
||||||
|
|
||||||
|
def getLeft(self):
    """Return the coordinate of the entity's left edge.

    The shapeless base class cannot answer this; the circular and
    rectangular subclasses override it.

    Raises:
        NotImplementedError: always (a subclass of ``Exception``, so
            existing ``except Exception`` handlers keep working).
    """
    raise NotImplementedError(
        type(self).__name__ + ' does not implement getLeft()')
|
||||||
|
|
||||||
|
def getRight(self):
    """Return the coordinate of the entity's right edge.

    The shapeless base class cannot answer this; the circular and
    rectangular subclasses override it.

    Raises:
        NotImplementedError: always (a subclass of ``Exception``, so
            existing ``except Exception`` handlers keep working).
    """
    raise NotImplementedError(
        type(self).__name__ + ' does not implement getRight()')
|
||||||
|
|
||||||
|
def getCenter(self):
    """Return the ``(x, y)`` coordinates of the entity's centre.

    The shapeless base class cannot answer this; the circular and
    rectangular subclasses override it.

    Raises:
        NotImplementedError: always (a subclass of ``Exception``, so
            existing ``except Exception`` handlers keep working).
    """
    raise NotImplementedError(
        type(self).__name__ + ' does not implement getCenter()')
|
||||||
|
|
||||||
|
|
||||||
class CircularEntity(Entity):
|
class CircularEntity(Entity):
|
||||||
def __init__(self, env):
|
def __init__(self, env):
|
||||||
@ -208,6 +241,24 @@ class CircularEntity(Entity):
|
|||||||
raise Exception(
|
raise Exception(
|
||||||
'[!] Shape "circle" does not know how to collide with shape "'+str(other.shape)+'"')
|
'[!] Shape "circle" does not know how to collide with shape "'+str(other.shape)+'"')
|
||||||
|
|
||||||
|
def getQuasiRadius(self):
    """Return the circle's ``radius``; for a circle no estimate is needed."""
    quasi_radius = self.radius
    return quasi_radius
|
||||||
|
|
||||||
|
def getTop(self):
    """Return the y coordinate of the circle's top edge.

    The centre ``pos[1]`` is scaled by ``env.height`` (presumably to
    pixels) and the radius is subtracted.
    """
    centre_y = self.pos[1]*self.env.height
    return centre_y - self.radius
|
||||||
|
|
||||||
|
def getBottom(self):
    """Return the y coordinate of the circle's bottom edge.

    The centre ``pos[1]`` is scaled by ``env.height`` (presumably to
    pixels) and the radius is added.
    """
    centre_y = self.pos[1]*self.env.height
    return centre_y + self.radius
|
||||||
|
|
||||||
|
def getLeft(self):
    """Return the x coordinate of the circle's left edge.

    The centre ``pos[0]`` is scaled by ``env.width`` (presumably to
    pixels) and the radius is subtracted.
    """
    centre_x = self.pos[0]*self.env.width
    return centre_x - self.radius
|
||||||
|
|
||||||
|
def getRight(self):
    """Return the x coordinate of the circle's right edge.

    The centre ``pos[0]`` is scaled by ``env.width`` (presumably to
    pixels) and the radius is added.
    """
    centre_x = self.pos[0]*self.env.width
    return centre_x + self.radius
|
||||||
|
|
||||||
|
def getCenter(self):
    """Return the circle's centre as an ``(x, y)`` tuple.

    ``pos`` is scaled by ``env.width`` and ``env.height`` respectively.
    """
    rel_x, rel_y = self.pos[0], self.pos[1]
    return rel_x*self.env.width, rel_y*self.env.height
|
||||||
|
|
||||||
|
|
||||||
class RectangularEntity(Entity):
|
class RectangularEntity(Entity):
|
||||||
def __init__(self, env):
|
def __init__(self, env):
|
||||||
@ -228,6 +279,58 @@ class RectangularEntity(Entity):
|
|||||||
raise Exception(
|
raise Exception(
|
||||||
'[!] Collisions in this direction not implemented for shape "rectangle"')
|
'[!] Collisions in this direction not implemented for shape "rectangle"')
|
||||||
|
|
||||||
|
def physics_step(self):
    """Advance this rectangle by one physics tick.

    Steps, in order:
      1. add ``acc`` scaled by ``env.acc_fac`` to ``speed``;
      2. rescale the velocity so its magnitude respects ``min_speed`` and
         ``max_speed``;
      3. move ``pos`` by the velocity scaled by ``env.speed_fac``;
      4. handle the playfield border: through ``calc_void_collision``
         when the world is not a torus and ``void_collidable`` is set,
         otherwise by wrapping both coordinates modulo 1;
      5. store the velocity divided by ``1 + drag`` and the new position.

    ``pos`` is the rectangle's minimum-x / minimum-y corner in normalised
    coordinates; ``width`` and ``height`` are divided by ``env.width`` and
    ``env.height`` before being compared with the ``[0, 1]`` playfield.
    """
    x, y = self.pos
    vx, vy = self.speed
    ax, ay = self.acc
    vx, vy = vx+ax*self.env.acc_fac, vy+ay*self.env.acc_fac

    speeds = math.sqrt(vx**2 + vy**2)
    # A zero velocity has no direction to rescale along; without this
    # guard a positive min_speed raised ZeroDivisionError for an entity
    # at rest.
    if speeds > 0:
        if speeds < self.min_speed:
            vx, vy = vx/speeds*self.min_speed, vy/speeds*self.min_speed
        if speeds > self.max_speed:
            vx, vy = vx/speeds*self.max_speed, vy/speeds*self.max_speed

    x, y = x+vx*self.env.speed_fac, y+vy*self.env.speed_fac

    if not self.env.torus_topology and self.void_collidable:
        rel_w = self.width/self.env.width
        if x+rel_w > 1 or x < 0:
            if x < 0:
                x, y, vx, vy = self.calc_void_collision(
                    x < 0, x, y, vx, vy)
            else:
                # Resolve the collision for the far edge, then shift back
                # to the corner that ``pos`` stores.
                x, y, vx, vy = self.calc_void_collision(
                    x < 0, x+rel_w, y, vx, vy)
                x -= rel_w
        rel_h = self.height/self.env.height
        if y+rel_h > 1 or y < 0:
            # NOTE(review): the direction code tests ``x < 0`` although
            # this is the y axis; kept as-is because calc_void_collision
            # is not visible here -- confirm whether ``y < 0`` was meant.
            if y < 0:
                x, y, vx, vy = self.calc_void_collision(
                    2 + (x < 0), x, y, vx, vy)
            else:
                x, y, vx, vy = self.calc_void_collision(
                    2 + (x < 0), x, y+rel_h, vx, vy)
                y -= rel_h
    else:
        # NOTE(review): entities with void_collidable == False also end up
        # here in a non-torus world and therefore wrap around -- confirm
        # that this is intended.
        x = x % 1
        y = y % 1

    self.speed = vx/(1+self.drag), vy/(1+self.drag)
    self.pos = x, y
|
||||||
|
|
||||||
|
def getQuasiRadius(self):
    """Return ``width + height`` as a coarse radius-like size estimate."""
    horizontal, vertical = self.width, self.height
    return horizontal + vertical
|
||||||
|
|
||||||
|
def getTop(self):
    """Return the y coordinate of the rectangle's top edge.

    This is ``pos[1]`` scaled by ``env.height`` (presumably to pixels).
    """
    top_edge = self.pos[1]*self.env.height
    return top_edge
|
||||||
|
|
||||||
|
def getBottom(self):
    """Return the y coordinate of the rectangle's bottom edge.

    This is ``pos[1]`` scaled by ``env.height`` plus the rectangle's
    ``height``.
    """
    top_edge = self.pos[1]*self.env.height
    return top_edge + self.height
|
||||||
|
|
||||||
|
def getLeft(self):
    """Return the x coordinate of the rectangle's left edge.

    This is ``pos[0]`` scaled by ``env.width`` (presumably to pixels).
    """
    left_edge = self.pos[0]*self.env.width
    return left_edge
|
||||||
|
|
||||||
|
def getRight(self):
    """Return the x coordinate of the rectangle's right edge.

    This is the left edge (``pos[0]`` scaled by ``env.width``) plus the
    rectangle's ``width``. The previous version also multiplied by
    ``env.height``, which put the right edge far outside the playfield.
    """
    left_edge = self.pos[0]*self.env.width
    return left_edge + self.width
|
||||||
|
|
||||||
|
def getCenter(self):
    """Return the rectangle's centre as an ``(x, y)`` tuple.

    Each coordinate is the scaled corner position plus half of the
    rectangle's extent along that axis.
    """
    centre_x = self.pos[0]*self.env.width + self.width/2
    centre_y = self.pos[1]*self.env.height + self.height/2
    return centre_x, centre_y
|
||||||
|
|
||||||
|
|
||||||
class Agent(CircularEntity):
|
class Agent(CircularEntity):
|
||||||
def __init__(self, env):
|
def __init__(self, env):
|
||||||
@ -252,6 +355,30 @@ class Agent(CircularEntity):
|
|||||||
raise Exception('Unsupported controll_type')
|
raise Exception('Unsupported controll_type')
|
||||||
|
|
||||||
|
|
||||||
|
# Does not work! Don't use!
|
||||||
|
class PongAgent(RectangularEntity):
    """Rectangular agent steered only along the second input axis.

    NOTE: the source comment preceding this class flags it as not working.
    """

    def __init__(self, env):
        """Create a solid, movable blue rectangle at the playfield centre."""
        super().__init__(env)
        self.pos = (0.5, 0.5)
        self.col = (0, 0, 255)
        # drag and control scheme are taken over from the environment
        self.drag = self.env.agent_drag
        self.controll_type = self.env.controll_type
        self.solid = True
        self.movable = True

    def controll_step(self):
        """Apply the current input, then let the env check collisions."""
        self._read_input()
        self.env.check_collisions_for(self)

    def _read_input(self):
        """Map ``env.inp[1]`` (shifted by -0.5) onto speed or acceleration.

        The horizontal component is always set to 0.

        Raises:
            Exception: if ``controll_type`` is neither 'SPEED' nor 'ACC'.
        """
        mode = self.controll_type
        if mode not in ('SPEED', 'ACC'):
            raise Exception('Unsupported controll_type')
        vertical = self.env.inp[1] - 0.5
        if mode == 'SPEED':
            self.speed = 0, vertical
        else:
            self.acc = 0, vertical
|
||||||
|
|
||||||
|
|
||||||
class Enemy(Entity):
|
class Enemy(Entity):
|
||||||
def __init__(self, env):
|
def __init__(self, env):
|
||||||
super(Enemy, self).__init__(env)
|
super(Enemy, self).__init__(env)
|
||||||
@ -349,6 +476,33 @@ class Collectable(CircularEntity):
|
|||||||
self.env.check_collisions_for(self)
|
self.env.check_collisions_for(self)
|
||||||
|
|
||||||
|
|
||||||
|
class RectCollectable(RectangularEntity):
    """Rectangular entity that rewards the env when a collector hits it."""

    def __init__(self, env):
        """Set up a collectable worth 10 reward with no collector classes."""
        super().__init__(env)
        self.avaible = True  # spelling is part of the attribute name
        self.enforce_not_on_barrier = False
        self.reward = 10
        # classes whose instances may collect this entity
        self.collectors = []

    def on_collision(self, other, depth):
        """Run the base collision handling, then barrier/collector hooks.

        A ``Barrier`` triggers ``on_barrier_collision``. Otherwise, for
        every class in ``collectors`` that ``other`` is an instance of,
        ``other.on_collect(self)`` and ``self.on_collected()`` are called.
        """
        super().on_collision(other, depth)
        if isinstance(other, Barrier):
            self.on_barrier_collision()
            return
        for collector_cls in self.collectors:
            if not isinstance(other, collector_cls):
                continue
            other.on_collect(self)
            self.on_collected()

    def on_collected(self):
        """Add this entity's ``reward`` to ``env.new_reward``."""
        self.env.new_reward += self.reward

    def on_barrier_collision(self):
        """Move to a random position if barriers must be avoided.

        Does nothing unless ``enforce_not_on_barrier`` is set; after the
        move, collisions are checked again for the new position.
        """
        if not self.enforce_not_on_barrier:
            return
        self.pos = (self.env.random(), self.env.random())
        self.env.check_collisions_for(self)
|
||||||
|
|
||||||
|
|
||||||
class Reward(Collectable):
|
class Reward(Collectable):
|
||||||
def __init__(self, env):
|
def __init__(self, env):
|
||||||
super(Reward, self).__init__(env)
|
super(Reward, self).__init__(env)
|
||||||
@ -479,6 +633,14 @@ class Goal(Collectable):
|
|||||||
self.collectors = [Ball]
|
self.collectors = [Ball]
|
||||||
|
|
||||||
|
|
||||||
|
class RectGoal(RectCollectable):
    """Green rectangular collectable worth 500 that only a Ball collects."""

    def __init__(self, env):
        """Override colour, reward and collector list of the base class."""
        super().__init__(env)
        self.col = (0, 200, 0)
        self.reward = 500
        self.collectors = [Ball]
|
||||||
|
|
||||||
|
|
||||||
class TeleportingGoal(Goal):
|
class TeleportingGoal(Goal):
|
||||||
def __init__(self, env):
|
def __init__(self, env):
|
||||||
super(TeleportingGoal, self).__init__(env)
|
super(TeleportingGoal, self).__init__(env)
|
||||||
|
@ -15,7 +15,7 @@ from columbus.utils import soft_int, parseObs
|
|||||||
class ColumbusEnv(gym.Env):
|
class ColumbusEnv(gym.Env):
|
||||||
metadata = {'render.modes': ['human']}
|
metadata = {'render.modes': ['human']}
|
||||||
|
|
||||||
def __init__(self, observable=observables.Observable(), fps=60, env_seed=3.1, master_seed=None, start_pos=(0.5, 0.5), start_score=0, speed_fac=0.01, acc_fac=0.04, die_on_zero=False, return_on_score=-1, reward_mult=1, agent_drag=0, controll_type='SPEED', aux_reward_max=1, aux_penalty_max=0, aux_reward_discretize=0, void_is_type_barrier=True, void_damage=1, torus_topology=False, default_collision_elasticity=1, terminate_on_reward=False, agent_draw_path=False, clear_path_on_reset=True, max_steps=-1, value_color_mapper='tanh', width=720, height=720, agent_attrs={}):
|
def __init__(self, observable=observables.Observable(), fps=60, env_seed=3.1, master_seed=None, start_pos=(0.5, 0.5), start_score=0, speed_fac=0.01, acc_fac=0.04, die_on_zero=False, return_on_score=-1, reward_mult=1, agent_drag=0, controll_type='SPEED', aux_reward_max=1, aux_penalty_max=0, aux_reward_discretize=0, void_is_type_barrier=True, void_damage=1, torus_topology=False, default_collision_elasticity=1, terminate_on_reward=False, agent_draw_path=False, clear_path_on_reset=True, max_steps=-1, value_color_mapper='tanh', width=720, height=720, agent_attrs={}, agent_cls=entities.Agent, exception_for_unsupported_collision=True):
|
||||||
super(ColumbusEnv, self).__init__()
|
super(ColumbusEnv, self).__init__()
|
||||||
self.action_space = spaces.Box(
|
self.action_space = spaces.Box(
|
||||||
low=-1, high=1, shape=(2,), dtype=np.float32)
|
low=-1, high=1, shape=(2,), dtype=np.float32)
|
||||||
@ -63,8 +63,13 @@ class ColumbusEnv(gym.Env):
|
|||||||
self.clear_path_on_reset = clear_path_on_reset
|
self.clear_path_on_reset = clear_path_on_reset
|
||||||
self.path_decay = 0.1
|
self.path_decay = 0.1
|
||||||
|
|
||||||
|
if isinstance(agent_cls, str):
|
||||||
|
agent_cls = getattr(entities, agent_cls)
|
||||||
|
self.Agent_cls = agent_cls
|
||||||
self.agent_attrs = agent_attrs
|
self.agent_attrs = agent_attrs
|
||||||
|
|
||||||
|
self.exception_for_unsupported_collision = exception_for_unsupported_collision
|
||||||
|
|
||||||
if value_color_mapper == 'atan':
|
if value_color_mapper == 'atan':
|
||||||
def value_color_mapper(x): return th.atan(x*2)/0.786/2
|
def value_color_mapper(x): return th.atan(x*2)/0.786/2
|
||||||
elif value_color_mapper == 'tanh':
|
elif value_color_mapper == 'tanh':
|
||||||
@ -217,8 +222,8 @@ class ColumbusEnv(gym.Env):
|
|||||||
reward, self.new_reward, self.new_abs_reward = self.new_reward / \
|
reward, self.new_reward, self.new_abs_reward = self.new_reward / \
|
||||||
self.fps + self.new_abs_reward, 0, 0
|
self.fps + self.new_abs_reward, 0, 0
|
||||||
if not self.torus_topology:
|
if not self.torus_topology:
|
||||||
if self.agent.pos[0] < 0.001 or self.agent.pos[0] > 0.999 \
|
if self.agent.getTop() < 1 or self.agent.getBottom() > self.height-1 \
|
||||||
or self.agent.pos[1] < 0.001 or self.agent.pos[1] > 0.999:
|
or self.agent.getLeft() < 1 or self.agent.getRight() > self.width-1:
|
||||||
reward -= self.void_damage/self.fps
|
reward -= self.void_damage/self.fps
|
||||||
self.score += reward # aux_reward does not count towards the score
|
self.score += reward # aux_reward does not count towards the score
|
||||||
if self.aux_reward_max or self.aux_penalty_max:
|
if self.aux_reward_max or self.aux_penalty_max:
|
||||||
@ -252,8 +257,10 @@ class ColumbusEnv(gym.Env):
|
|||||||
elif shapes == ['circle', 'rect']:
|
elif shapes == ['circle', 'rect']:
|
||||||
return sum([abs(d) for d in e1._get_crash_force_dir(e2)])
|
return sum([abs(d) for d in e1._get_crash_force_dir(e2)])
|
||||||
else:
|
else:
|
||||||
raise Exception(
|
if self.exception_for_unsupported_collision:
|
||||||
'Checking for collision between unsupported shapes: '+str(shapes))
|
raise Exception(
|
||||||
|
'Checking for collision between unsupported shapes: '+str(shapes))
|
||||||
|
return 0.0
|
||||||
|
|
||||||
def kill_entity(self, target):
|
def kill_entity(self, target):
|
||||||
newEntities = []
|
newEntities = []
|
||||||
@ -270,7 +277,7 @@ class ColumbusEnv(gym.Env):
|
|||||||
# Expand this function
|
# Expand this function
|
||||||
|
|
||||||
def _spawnAgent(self):
|
def _spawnAgent(self):
|
||||||
self.agent = entities.Agent(self)
|
self.agent = self.Agent_cls(self)
|
||||||
self.agent.draw_path = self.agent_draw_path
|
self.agent.draw_path = self.agent_draw_path
|
||||||
for k, v in self.agent_attrs.items():
|
for k, v in self.agent_attrs.items():
|
||||||
setattr(self.agent, k, v)
|
setattr(self.agent, k, v)
|
||||||
@ -497,11 +504,11 @@ class ColumbusConfigDefined(ColumbusEnv):
|
|||||||
def is_unit(self, s):
    """Tell whether ``s`` is a plain number or a number with a unit suffix.

    Accepted are ``int``/``float`` objects, numeric strings such as
    ``'12'``, ``'-0.5'`` or ``'.5'``, and such a numeric string followed
    by one of the two-letter suffixes ``px``, ``em``, ``rx``, ``ry``,
    ``ct`` or ``au``.

    A minus sign is only accepted as the first character; previously it
    was accepted anywhere, so ``'5-3'`` and ``'-'`` counted as numbers.

    Returns:
        bool: True if ``s`` has one of the accepted forms.
    """
    def _is_number(text):
        # optional leading minus, then digits with at most one '.'
        if text.startswith('-'):
            text = text[1:]
        return text.replace('.', '', 1).isdigit()

    if type(s) in (int, float):
        return True
    if _is_number(s):
        return True
    num, unit = s[:-2], s[-2:]
    return unit in ('px', 'em', 'rx', 'ry', 'ct', 'au') and _is_number(num)
|
||||||
|
|
||||||
@ -556,11 +563,32 @@ class ColumbusConfigDefined(ColumbusEnv):
|
|||||||
else:
|
else:
|
||||||
v = v_raw
|
v = v_raw
|
||||||
if k.endswith('_rand'):
|
if k.endswith('_rand'):
|
||||||
n = k.replace('_rand', '')
|
if isinstance(v, int):
|
||||||
cur = getattr(
|
n = k.replace('_rand', '')
|
||||||
entity, n)
|
cur = getattr(
|
||||||
inc = int((v+0.99)*self.random())
|
entity, n)
|
||||||
setattr(entity, n, cur + inc)
|
inc = int((v+0.99)*self.random())
|
||||||
|
setattr(entity, n, cur + inc)
|
||||||
|
elif isinstance(v, float):
|
||||||
|
n = k.replace('_randf', '')
|
||||||
|
cur = getattr(
|
||||||
|
entity, n)
|
||||||
|
inc = v*self.random()
|
||||||
|
setattr(entity, n, cur + inc)
|
||||||
|
elif isinstance(v, list):
|
||||||
|
for vi, ve in enumerate(v):
|
||||||
|
if isinstance(v, int):
|
||||||
|
n = k.replace('_rand', '')
|
||||||
|
cur = getattr(
|
||||||
|
entity, n)
|
||||||
|
cur[vi] = int((v+0.99)*self.random())
|
||||||
|
setattr(entity, n, cur)
|
||||||
|
elif isinstance(v, float):
|
||||||
|
n = k.replace('_randf', '')
|
||||||
|
cur = getattr(
|
||||||
|
entity, n)
|
||||||
|
cur[vi] = v*self.random()
|
||||||
|
setattr(entity, n, cur)
|
||||||
elif k.endswith('_randf'):
|
elif k.endswith('_randf'):
|
||||||
n = k.replace('_randf', '')
|
n = k.replace('_randf', '')
|
||||||
cur = getattr(
|
cur = getattr(
|
||||||
|
@ -151,7 +151,7 @@ class RayObservable(Observable):
|
|||||||
'Can only raycast circular and rectangular entities!')
|
'Can only raycast circular and rectangular entities!')
|
||||||
sq_dist = ((self.env.agent.pos[0]-x)*self.env.width) ** 2 \
|
sq_dist = ((self.env.agent.pos[0]-x)*self.env.width) ** 2 \
|
||||||
+ ((self.env.agent.pos[1]-y)*self.env.height) ** 2
|
+ ((self.env.agent.pos[1]-y)*self.env.height) ** 2
|
||||||
if sq_dist <= (radius + self.env.agent.radius + self.ray_len)**2:
|
if sq_dist <= (radius + self.env.agent.getQuasiRadius() + self.ray_len)**2:
|
||||||
entities_l.append(entity) # cannot use yield here!
|
entities_l.append(entity) # cannot use yield here!
|
||||||
return entities_l
|
return entities_l
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user