Minor cleanup and reorganized some class definitions
This commit is contained in:
parent
8132bb9321
commit
6d465c69c9
188
columbus/env.py
188
columbus/env.py
@ -11,6 +11,7 @@ import torch as th
|
||||
|
||||
|
||||
def parseObs(obsConf):
|
||||
# Parsing Observable Definitions
|
||||
if type(obsConf) == list:
|
||||
obs = []
|
||||
for i, c in enumerate(obsConf):
|
||||
@ -494,6 +495,96 @@ class ColumbusEnv(gym.Env):
|
||||
pygame.quit()
|
||||
|
||||
|
||||
class ColumbusConfigDefined(ColumbusEnv):
|
||||
# Allows defining Columbus Environments using dicts.
|
||||
# Intended to be used in combination with cw2 configuration.
|
||||
# Look into humanPlayer to see how this is supposed to be interfaced with.
|
||||
|
||||
def __init__(self, observable={}, env_seed=None, entities=[], fps=30, **kw):
|
||||
super().__init__(
|
||||
observable=observable, fps=fps, env_seed=env_seed, **kw)
|
||||
self.entities_definitions = entities
|
||||
|
||||
def is_unit(self, s):
|
||||
if type(s) in [int, float]:
|
||||
return True
|
||||
if s.replace('.', '', 1).isdigit():
|
||||
return True
|
||||
num, unit = s[:-2], s[-2:]
|
||||
if unit in ['px', 'em', 'rx', 'ry', 'ct']:
|
||||
if num.replace('.', '', 1).isdigit():
|
||||
return True
|
||||
return False
|
||||
|
||||
def conv_unit(self, s, target='px', axis='x'):
|
||||
assert self.is_unit(s)
|
||||
if type(s) in [int, float]:
|
||||
return s
|
||||
if s.replace('.', '', 1).isdigit():
|
||||
if target == 'px':
|
||||
return int(s)
|
||||
return float(s)
|
||||
num, unit = s[:-2], s[-2:]
|
||||
num = float(num)
|
||||
if unit == 'rx':
|
||||
unit = 'px'
|
||||
axis = 'x'
|
||||
elif unit == 'ry':
|
||||
unit = 'px'
|
||||
axis = 'y'
|
||||
if unit == 'em':
|
||||
em = num
|
||||
elif unit == 'px':
|
||||
em = num / ({'x': self.width, 'y': self.height}[axis])
|
||||
elif unit == 'ct':
|
||||
em = num / 100
|
||||
else:
|
||||
raise Exception('Conversion not implemented')
|
||||
|
||||
if target == 'em':
|
||||
return em
|
||||
elif target == 'px':
|
||||
return int(em * ({'x': self.width, 'y': self.height}[axis]))
|
||||
|
||||
def setup(self):
|
||||
self.agent.pos = self.start_pos
|
||||
for i, e in enumerate(self.entities_definitions):
|
||||
Entity = getattr(entities, e['type'])
|
||||
for i in range(e.get('num', 1) + int(self.random()*(0.99+e.get('num_rand', 0)))):
|
||||
entity = Entity(self)
|
||||
conf = {k: v for k, v in e.items() if str(
|
||||
k) not in ['num', 'num_rand', 'type']}
|
||||
|
||||
for k, v_raw in conf.items():
|
||||
if k == 'pos':
|
||||
v = self.conv_unit(v_raw[0], target='em', axis='x'), self.conv_unit(
|
||||
v_raw[1], target='em', axis='y')
|
||||
elif k in ['width', 'height', 'radius']:
|
||||
v = self.conv_unit(
|
||||
v_raw, target='px', axis='y' if k == 'height' else 'x')
|
||||
else:
|
||||
v = v_raw
|
||||
if k.endswith('_rand'):
|
||||
n = k.replace('_rand', '')
|
||||
cur = getattr(
|
||||
entity, n)
|
||||
inc = int((v+0.99)*self.random())
|
||||
setattr(entity, n, cur + inc)
|
||||
elif k.endswith('_randf'):
|
||||
n = k.replace('_randf', '')
|
||||
cur = getattr(
|
||||
entity, n)
|
||||
inc = v*self.random()
|
||||
setattr(entity, n, cur + inc)
|
||||
else:
|
||||
setattr(entity, k, v)
|
||||
|
||||
self.entities.append(entity)
|
||||
|
||||
###
|
||||
# Custom Env Definitions
|
||||
|
||||
|
||||
class ColumbusTest3_1(ColumbusEnv):
|
||||
def __init__(self, observable=observables.CnnObservable(out_width=48, out_height=48), fps=30, aux_reward_max=1, **kw):
|
||||
super(ColumbusTest3_1, self).__init__(
|
||||
@ -833,89 +924,6 @@ class ColumbusFootball(ColumbusEnv):
|
||||
self.entities.append(entities.FlyingFootballPlayer(self, ball))
|
||||
|
||||
|
||||
class ColumbusConfigDefined(ColumbusEnv):
|
||||
def __init__(self, observable={}, env_seed=None, entities=[], fps=30, **kw):
|
||||
super().__init__(
|
||||
observable=observable, fps=fps, env_seed=env_seed, **kw)
|
||||
self.entities_definitions = entities
|
||||
|
||||
def is_unit(self, s):
|
||||
if type(s) in [int, float]:
|
||||
return True
|
||||
if s.replace('.', '', 1).isdigit():
|
||||
return True
|
||||
num, unit = s[:-2], s[-2:]
|
||||
if unit in ['px', 'em', 'rx', 'ry', 'ct']:
|
||||
if num.replace('.', '', 1).isdigit():
|
||||
return True
|
||||
return False
|
||||
|
||||
def conv_unit(self, s, target='px', axis='x'):
|
||||
assert self.is_unit(s)
|
||||
if type(s) in [int, float]:
|
||||
return s
|
||||
if s.replace('.', '', 1).isdigit():
|
||||
if target == 'px':
|
||||
return int(s)
|
||||
return float(s)
|
||||
num, unit = s[:-2], s[-2:]
|
||||
num = float(num)
|
||||
if unit == 'rx':
|
||||
unit = 'px'
|
||||
axis = 'x'
|
||||
elif unit == 'ry':
|
||||
unit = 'px'
|
||||
axis = 'y'
|
||||
if unit == 'em':
|
||||
em = num
|
||||
elif unit == 'px':
|
||||
em = num / ({'x': self.width, 'y': self.height}[axis])
|
||||
elif unit == 'ct':
|
||||
em = num / 100
|
||||
else:
|
||||
raise Exception('Conversion not implemented')
|
||||
|
||||
if target == 'em':
|
||||
return em
|
||||
elif target == 'px':
|
||||
return int(em * ({'x': self.width, 'y': self.height}[axis]))
|
||||
|
||||
def setup(self):
|
||||
self.agent.pos = self.start_pos
|
||||
for i, e in enumerate(self.entities_definitions):
|
||||
Entity = getattr(entities, e['type'])
|
||||
for i in range(e.get('num', 1) + int(self.random()*(0.99+e.get('num_rand', 0)))):
|
||||
entity = Entity(self)
|
||||
conf = {k: v for k, v in e.items() if str(
|
||||
k) not in ['num', 'num_rand', 'type']}
|
||||
|
||||
for k, v_raw in conf.items():
|
||||
if k == 'pos':
|
||||
v = self.conv_unit(v_raw[0], target='em', axis='x'), self.conv_unit(
|
||||
v_raw[1], target='em', axis='y')
|
||||
elif k in ['width', 'height', 'radius']:
|
||||
v = self.conv_unit(
|
||||
v_raw, target='px', axis='y' if k == 'height' else 'x')
|
||||
else:
|
||||
v = v_raw
|
||||
if k.endswith('_rand'):
|
||||
n = k.replace('_rand', '')
|
||||
cur = getattr(
|
||||
entity, n)
|
||||
inc = int((v+0.99)*self.random())
|
||||
setattr(entity, n, cur + inc)
|
||||
elif k.endswith('_randf'):
|
||||
n = k.replace('_randf', '')
|
||||
cur = getattr(
|
||||
entity, n)
|
||||
inc = v*self.random()
|
||||
setattr(entity, n, cur + inc)
|
||||
else:
|
||||
setattr(entity, k, v)
|
||||
|
||||
self.entities.append(entity)
|
||||
|
||||
|
||||
class ColumbusBlub(ColumbusEnv):
|
||||
def __init__(self, observable=observables.CompositionalObservable([observables.StateObservable(), observables.RayObservable(num_rays=6, chans=[entities.Enemy])]), env_seed=None, entities=[], fps=30, **kw):
|
||||
super().__init__(
|
||||
@ -929,7 +937,15 @@ class ColumbusBlub(ColumbusEnv):
|
||||
enemy.width, enemy.height = 200, 75
|
||||
self.entities.append(enemy)
|
||||
|
||||
|
||||
###
|
||||
# Registering Envs fro Gym
|
||||
register(
|
||||
id='ColumbusConfigDefined-v0',
|
||||
entry_point=ColumbusConfigDefined,
|
||||
max_episode_steps=30*60*2, # 2 min at default (30) fps
|
||||
)
|
||||
|
||||
# register(
|
||||
# id='ColumbusBlub-v0',
|
||||
# entry_point=ColumbusBlub,
|
||||
@ -1021,12 +1037,6 @@ register(
|
||||
# max_episode_steps=30*60*2,
|
||||
# )
|
||||
|
||||
register(
|
||||
id='ColumbusConfigDefined-v0',
|
||||
entry_point=ColumbusConfigDefined,
|
||||
max_episode_steps=30*60*2,
|
||||
)
|
||||
|
||||
register(
|
||||
id='ColumbusDemoEnvFootball-v0',
|
||||
entry_point=ColumbusDemoEnvFootball,
|
||||
|
Loading…
Reference in New Issue
Block a user