Minor cleanup and reorganized some class definitions
This commit is contained in:
parent
8132bb9321
commit
6d465c69c9
188
columbus/env.py
188
columbus/env.py
@ -11,6 +11,7 @@ import torch as th
|
|||||||
|
|
||||||
|
|
||||||
def parseObs(obsConf):
|
def parseObs(obsConf):
|
||||||
|
# Parsing Observable Definitions
|
||||||
if type(obsConf) == list:
|
if type(obsConf) == list:
|
||||||
obs = []
|
obs = []
|
||||||
for i, c in enumerate(obsConf):
|
for i, c in enumerate(obsConf):
|
||||||
@ -494,6 +495,96 @@ class ColumbusEnv(gym.Env):
|
|||||||
pygame.quit()
|
pygame.quit()
|
||||||
|
|
||||||
|
|
||||||
|
class ColumbusConfigDefined(ColumbusEnv):
|
||||||
|
# Allows defining Columbus Environments using dicts.
|
||||||
|
# Intended to be used in combination with cw2 configuration.
|
||||||
|
# Look into humanPlayer to see how this is supposed to be interfaced with.
|
||||||
|
|
||||||
|
def __init__(self, observable={}, env_seed=None, entities=[], fps=30, **kw):
|
||||||
|
super().__init__(
|
||||||
|
observable=observable, fps=fps, env_seed=env_seed, **kw)
|
||||||
|
self.entities_definitions = entities
|
||||||
|
|
||||||
|
def is_unit(self, s):
|
||||||
|
if type(s) in [int, float]:
|
||||||
|
return True
|
||||||
|
if s.replace('.', '', 1).isdigit():
|
||||||
|
return True
|
||||||
|
num, unit = s[:-2], s[-2:]
|
||||||
|
if unit in ['px', 'em', 'rx', 'ry', 'ct']:
|
||||||
|
if num.replace('.', '', 1).isdigit():
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def conv_unit(self, s, target='px', axis='x'):
|
||||||
|
assert self.is_unit(s)
|
||||||
|
if type(s) in [int, float]:
|
||||||
|
return s
|
||||||
|
if s.replace('.', '', 1).isdigit():
|
||||||
|
if target == 'px':
|
||||||
|
return int(s)
|
||||||
|
return float(s)
|
||||||
|
num, unit = s[:-2], s[-2:]
|
||||||
|
num = float(num)
|
||||||
|
if unit == 'rx':
|
||||||
|
unit = 'px'
|
||||||
|
axis = 'x'
|
||||||
|
elif unit == 'ry':
|
||||||
|
unit = 'px'
|
||||||
|
axis = 'y'
|
||||||
|
if unit == 'em':
|
||||||
|
em = num
|
||||||
|
elif unit == 'px':
|
||||||
|
em = num / ({'x': self.width, 'y': self.height}[axis])
|
||||||
|
elif unit == 'ct':
|
||||||
|
em = num / 100
|
||||||
|
else:
|
||||||
|
raise Exception('Conversion not implemented')
|
||||||
|
|
||||||
|
if target == 'em':
|
||||||
|
return em
|
||||||
|
elif target == 'px':
|
||||||
|
return int(em * ({'x': self.width, 'y': self.height}[axis]))
|
||||||
|
|
||||||
|
def setup(self):
|
||||||
|
self.agent.pos = self.start_pos
|
||||||
|
for i, e in enumerate(self.entities_definitions):
|
||||||
|
Entity = getattr(entities, e['type'])
|
||||||
|
for i in range(e.get('num', 1) + int(self.random()*(0.99+e.get('num_rand', 0)))):
|
||||||
|
entity = Entity(self)
|
||||||
|
conf = {k: v for k, v in e.items() if str(
|
||||||
|
k) not in ['num', 'num_rand', 'type']}
|
||||||
|
|
||||||
|
for k, v_raw in conf.items():
|
||||||
|
if k == 'pos':
|
||||||
|
v = self.conv_unit(v_raw[0], target='em', axis='x'), self.conv_unit(
|
||||||
|
v_raw[1], target='em', axis='y')
|
||||||
|
elif k in ['width', 'height', 'radius']:
|
||||||
|
v = self.conv_unit(
|
||||||
|
v_raw, target='px', axis='y' if k == 'height' else 'x')
|
||||||
|
else:
|
||||||
|
v = v_raw
|
||||||
|
if k.endswith('_rand'):
|
||||||
|
n = k.replace('_rand', '')
|
||||||
|
cur = getattr(
|
||||||
|
entity, n)
|
||||||
|
inc = int((v+0.99)*self.random())
|
||||||
|
setattr(entity, n, cur + inc)
|
||||||
|
elif k.endswith('_randf'):
|
||||||
|
n = k.replace('_randf', '')
|
||||||
|
cur = getattr(
|
||||||
|
entity, n)
|
||||||
|
inc = v*self.random()
|
||||||
|
setattr(entity, n, cur + inc)
|
||||||
|
else:
|
||||||
|
setattr(entity, k, v)
|
||||||
|
|
||||||
|
self.entities.append(entity)
|
||||||
|
|
||||||
|
###
|
||||||
|
# Custom Env Definitions
|
||||||
|
|
||||||
|
|
||||||
class ColumbusTest3_1(ColumbusEnv):
|
class ColumbusTest3_1(ColumbusEnv):
|
||||||
def __init__(self, observable=observables.CnnObservable(out_width=48, out_height=48), fps=30, aux_reward_max=1, **kw):
|
def __init__(self, observable=observables.CnnObservable(out_width=48, out_height=48), fps=30, aux_reward_max=1, **kw):
|
||||||
super(ColumbusTest3_1, self).__init__(
|
super(ColumbusTest3_1, self).__init__(
|
||||||
@ -833,89 +924,6 @@ class ColumbusFootball(ColumbusEnv):
|
|||||||
self.entities.append(entities.FlyingFootballPlayer(self, ball))
|
self.entities.append(entities.FlyingFootballPlayer(self, ball))
|
||||||
|
|
||||||
|
|
||||||
class ColumbusConfigDefined(ColumbusEnv):
|
|
||||||
def __init__(self, observable={}, env_seed=None, entities=[], fps=30, **kw):
|
|
||||||
super().__init__(
|
|
||||||
observable=observable, fps=fps, env_seed=env_seed, **kw)
|
|
||||||
self.entities_definitions = entities
|
|
||||||
|
|
||||||
def is_unit(self, s):
|
|
||||||
if type(s) in [int, float]:
|
|
||||||
return True
|
|
||||||
if s.replace('.', '', 1).isdigit():
|
|
||||||
return True
|
|
||||||
num, unit = s[:-2], s[-2:]
|
|
||||||
if unit in ['px', 'em', 'rx', 'ry', 'ct']:
|
|
||||||
if num.replace('.', '', 1).isdigit():
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def conv_unit(self, s, target='px', axis='x'):
|
|
||||||
assert self.is_unit(s)
|
|
||||||
if type(s) in [int, float]:
|
|
||||||
return s
|
|
||||||
if s.replace('.', '', 1).isdigit():
|
|
||||||
if target == 'px':
|
|
||||||
return int(s)
|
|
||||||
return float(s)
|
|
||||||
num, unit = s[:-2], s[-2:]
|
|
||||||
num = float(num)
|
|
||||||
if unit == 'rx':
|
|
||||||
unit = 'px'
|
|
||||||
axis = 'x'
|
|
||||||
elif unit == 'ry':
|
|
||||||
unit = 'px'
|
|
||||||
axis = 'y'
|
|
||||||
if unit == 'em':
|
|
||||||
em = num
|
|
||||||
elif unit == 'px':
|
|
||||||
em = num / ({'x': self.width, 'y': self.height}[axis])
|
|
||||||
elif unit == 'ct':
|
|
||||||
em = num / 100
|
|
||||||
else:
|
|
||||||
raise Exception('Conversion not implemented')
|
|
||||||
|
|
||||||
if target == 'em':
|
|
||||||
return em
|
|
||||||
elif target == 'px':
|
|
||||||
return int(em * ({'x': self.width, 'y': self.height}[axis]))
|
|
||||||
|
|
||||||
def setup(self):
|
|
||||||
self.agent.pos = self.start_pos
|
|
||||||
for i, e in enumerate(self.entities_definitions):
|
|
||||||
Entity = getattr(entities, e['type'])
|
|
||||||
for i in range(e.get('num', 1) + int(self.random()*(0.99+e.get('num_rand', 0)))):
|
|
||||||
entity = Entity(self)
|
|
||||||
conf = {k: v for k, v in e.items() if str(
|
|
||||||
k) not in ['num', 'num_rand', 'type']}
|
|
||||||
|
|
||||||
for k, v_raw in conf.items():
|
|
||||||
if k == 'pos':
|
|
||||||
v = self.conv_unit(v_raw[0], target='em', axis='x'), self.conv_unit(
|
|
||||||
v_raw[1], target='em', axis='y')
|
|
||||||
elif k in ['width', 'height', 'radius']:
|
|
||||||
v = self.conv_unit(
|
|
||||||
v_raw, target='px', axis='y' if k == 'height' else 'x')
|
|
||||||
else:
|
|
||||||
v = v_raw
|
|
||||||
if k.endswith('_rand'):
|
|
||||||
n = k.replace('_rand', '')
|
|
||||||
cur = getattr(
|
|
||||||
entity, n)
|
|
||||||
inc = int((v+0.99)*self.random())
|
|
||||||
setattr(entity, n, cur + inc)
|
|
||||||
elif k.endswith('_randf'):
|
|
||||||
n = k.replace('_randf', '')
|
|
||||||
cur = getattr(
|
|
||||||
entity, n)
|
|
||||||
inc = v*self.random()
|
|
||||||
setattr(entity, n, cur + inc)
|
|
||||||
else:
|
|
||||||
setattr(entity, k, v)
|
|
||||||
|
|
||||||
self.entities.append(entity)
|
|
||||||
|
|
||||||
|
|
||||||
class ColumbusBlub(ColumbusEnv):
|
class ColumbusBlub(ColumbusEnv):
|
||||||
def __init__(self, observable=observables.CompositionalObservable([observables.StateObservable(), observables.RayObservable(num_rays=6, chans=[entities.Enemy])]), env_seed=None, entities=[], fps=30, **kw):
|
def __init__(self, observable=observables.CompositionalObservable([observables.StateObservable(), observables.RayObservable(num_rays=6, chans=[entities.Enemy])]), env_seed=None, entities=[], fps=30, **kw):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@ -929,7 +937,15 @@ class ColumbusBlub(ColumbusEnv):
|
|||||||
enemy.width, enemy.height = 200, 75
|
enemy.width, enemy.height = 200, 75
|
||||||
self.entities.append(enemy)
|
self.entities.append(enemy)
|
||||||
|
|
||||||
|
|
||||||
###
|
###
|
||||||
|
# Registering Envs fro Gym
|
||||||
|
register(
|
||||||
|
id='ColumbusConfigDefined-v0',
|
||||||
|
entry_point=ColumbusConfigDefined,
|
||||||
|
max_episode_steps=30*60*2, # 2 min at default (30) fps
|
||||||
|
)
|
||||||
|
|
||||||
# register(
|
# register(
|
||||||
# id='ColumbusBlub-v0',
|
# id='ColumbusBlub-v0',
|
||||||
# entry_point=ColumbusBlub,
|
# entry_point=ColumbusBlub,
|
||||||
@ -1021,12 +1037,6 @@ register(
|
|||||||
# max_episode_steps=30*60*2,
|
# max_episode_steps=30*60*2,
|
||||||
# )
|
# )
|
||||||
|
|
||||||
register(
|
|
||||||
id='ColumbusConfigDefined-v0',
|
|
||||||
entry_point=ColumbusConfigDefined,
|
|
||||||
max_episode_steps=30*60*2,
|
|
||||||
)
|
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='ColumbusDemoEnvFootball-v0',
|
id='ColumbusDemoEnvFootball-v0',
|
||||||
entry_point=ColumbusDemoEnvFootball,
|
entry_point=ColumbusDemoEnvFootball,
|
||||||
|
Loading…
Reference in New Issue
Block a user