Minor cleanup and reorganized some class definitions

This commit is contained in:
Dominik Moritz Roth 2022-12-06 12:16:42 +01:00
parent 8132bb9321
commit 6d465c69c9

View File

@ -11,6 +11,7 @@ import torch as th
def parseObs(obsConf): def parseObs(obsConf):
# Parsing Observable Definitions
if type(obsConf) == list: if type(obsConf) == list:
obs = [] obs = []
for i, c in enumerate(obsConf): for i, c in enumerate(obsConf):
@ -494,6 +495,96 @@ class ColumbusEnv(gym.Env):
pygame.quit() pygame.quit()
class ColumbusConfigDefined(ColumbusEnv):
# Allows defining Columbus Environments using dicts.
# Intended to be used in combination with cw2 configuration.
# Look into humanPlayer to see how this is supposed to be interfaced with.
def __init__(self, observable={}, env_seed=None, entities=[], fps=30, **kw):
super().__init__(
observable=observable, fps=fps, env_seed=env_seed, **kw)
self.entities_definitions = entities
def is_unit(self, s):
if type(s) in [int, float]:
return True
if s.replace('.', '', 1).isdigit():
return True
num, unit = s[:-2], s[-2:]
if unit in ['px', 'em', 'rx', 'ry', 'ct']:
if num.replace('.', '', 1).isdigit():
return True
return False
def conv_unit(self, s, target='px', axis='x'):
assert self.is_unit(s)
if type(s) in [int, float]:
return s
if s.replace('.', '', 1).isdigit():
if target == 'px':
return int(s)
return float(s)
num, unit = s[:-2], s[-2:]
num = float(num)
if unit == 'rx':
unit = 'px'
axis = 'x'
elif unit == 'ry':
unit = 'px'
axis = 'y'
if unit == 'em':
em = num
elif unit == 'px':
em = num / ({'x': self.width, 'y': self.height}[axis])
elif unit == 'ct':
em = num / 100
else:
raise Exception('Conversion not implemented')
if target == 'em':
return em
elif target == 'px':
return int(em * ({'x': self.width, 'y': self.height}[axis]))
def setup(self):
self.agent.pos = self.start_pos
for i, e in enumerate(self.entities_definitions):
Entity = getattr(entities, e['type'])
for i in range(e.get('num', 1) + int(self.random()*(0.99+e.get('num_rand', 0)))):
entity = Entity(self)
conf = {k: v for k, v in e.items() if str(
k) not in ['num', 'num_rand', 'type']}
for k, v_raw in conf.items():
if k == 'pos':
v = self.conv_unit(v_raw[0], target='em', axis='x'), self.conv_unit(
v_raw[1], target='em', axis='y')
elif k in ['width', 'height', 'radius']:
v = self.conv_unit(
v_raw, target='px', axis='y' if k == 'height' else 'x')
else:
v = v_raw
if k.endswith('_rand'):
n = k.replace('_rand', '')
cur = getattr(
entity, n)
inc = int((v+0.99)*self.random())
setattr(entity, n, cur + inc)
elif k.endswith('_randf'):
n = k.replace('_randf', '')
cur = getattr(
entity, n)
inc = v*self.random()
setattr(entity, n, cur + inc)
else:
setattr(entity, k, v)
self.entities.append(entity)
###
# Custom Env Definitions
class ColumbusTest3_1(ColumbusEnv): class ColumbusTest3_1(ColumbusEnv):
def __init__(self, observable=observables.CnnObservable(out_width=48, out_height=48), fps=30, aux_reward_max=1, **kw): def __init__(self, observable=observables.CnnObservable(out_width=48, out_height=48), fps=30, aux_reward_max=1, **kw):
super(ColumbusTest3_1, self).__init__( super(ColumbusTest3_1, self).__init__(
@ -833,89 +924,6 @@ class ColumbusFootball(ColumbusEnv):
self.entities.append(entities.FlyingFootballPlayer(self, ball)) self.entities.append(entities.FlyingFootballPlayer(self, ball))
class ColumbusConfigDefined(ColumbusEnv):
def __init__(self, observable={}, env_seed=None, entities=[], fps=30, **kw):
super().__init__(
observable=observable, fps=fps, env_seed=env_seed, **kw)
self.entities_definitions = entities
def is_unit(self, s):
if type(s) in [int, float]:
return True
if s.replace('.', '', 1).isdigit():
return True
num, unit = s[:-2], s[-2:]
if unit in ['px', 'em', 'rx', 'ry', 'ct']:
if num.replace('.', '', 1).isdigit():
return True
return False
def conv_unit(self, s, target='px', axis='x'):
assert self.is_unit(s)
if type(s) in [int, float]:
return s
if s.replace('.', '', 1).isdigit():
if target == 'px':
return int(s)
return float(s)
num, unit = s[:-2], s[-2:]
num = float(num)
if unit == 'rx':
unit = 'px'
axis = 'x'
elif unit == 'ry':
unit = 'px'
axis = 'y'
if unit == 'em':
em = num
elif unit == 'px':
em = num / ({'x': self.width, 'y': self.height}[axis])
elif unit == 'ct':
em = num / 100
else:
raise Exception('Conversion not implemented')
if target == 'em':
return em
elif target == 'px':
return int(em * ({'x': self.width, 'y': self.height}[axis]))
def setup(self):
self.agent.pos = self.start_pos
for i, e in enumerate(self.entities_definitions):
Entity = getattr(entities, e['type'])
for i in range(e.get('num', 1) + int(self.random()*(0.99+e.get('num_rand', 0)))):
entity = Entity(self)
conf = {k: v for k, v in e.items() if str(
k) not in ['num', 'num_rand', 'type']}
for k, v_raw in conf.items():
if k == 'pos':
v = self.conv_unit(v_raw[0], target='em', axis='x'), self.conv_unit(
v_raw[1], target='em', axis='y')
elif k in ['width', 'height', 'radius']:
v = self.conv_unit(
v_raw, target='px', axis='y' if k == 'height' else 'x')
else:
v = v_raw
if k.endswith('_rand'):
n = k.replace('_rand', '')
cur = getattr(
entity, n)
inc = int((v+0.99)*self.random())
setattr(entity, n, cur + inc)
elif k.endswith('_randf'):
n = k.replace('_randf', '')
cur = getattr(
entity, n)
inc = v*self.random()
setattr(entity, n, cur + inc)
else:
setattr(entity, k, v)
self.entities.append(entity)
class ColumbusBlub(ColumbusEnv): class ColumbusBlub(ColumbusEnv):
def __init__(self, observable=observables.CompositionalObservable([observables.StateObservable(), observables.RayObservable(num_rays=6, chans=[entities.Enemy])]), env_seed=None, entities=[], fps=30, **kw): def __init__(self, observable=observables.CompositionalObservable([observables.StateObservable(), observables.RayObservable(num_rays=6, chans=[entities.Enemy])]), env_seed=None, entities=[], fps=30, **kw):
super().__init__( super().__init__(
@ -929,7 +937,15 @@ class ColumbusBlub(ColumbusEnv):
enemy.width, enemy.height = 200, 75 enemy.width, enemy.height = 200, 75
self.entities.append(enemy) self.entities.append(enemy)
### ###
# Registering Envs fro Gym
register(
id='ColumbusConfigDefined-v0',
entry_point=ColumbusConfigDefined,
max_episode_steps=30*60*2, # 2 min at default (30) fps
)
# register( # register(
# id='ColumbusBlub-v0', # id='ColumbusBlub-v0',
# entry_point=ColumbusBlub, # entry_point=ColumbusBlub,
@ -1021,12 +1037,6 @@ register(
# max_episode_steps=30*60*2, # max_episode_steps=30*60*2,
# ) # )
register(
id='ColumbusConfigDefined-v0',
entry_point=ColumbusConfigDefined,
max_episode_steps=30*60*2,
)
register( register(
id='ColumbusDemoEnvFootball-v0', id='ColumbusDemoEnvFootball-v0',
entry_point=ColumbusDemoEnvFootball, entry_point=ColumbusDemoEnvFootball,