From 6d465c69c93761a7074c4d3063e731d5f47edd55 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Tue, 6 Dec 2022 12:16:42 +0100 Subject: [PATCH] Minor cleanup and reorganized some class definitions --- columbus/env.py | 188 +++++++++++++++++++++++++----------------------- 1 file changed, 99 insertions(+), 89 deletions(-) diff --git a/columbus/env.py b/columbus/env.py index acc8651..7c3d400 100644 --- a/columbus/env.py +++ b/columbus/env.py @@ -11,6 +11,7 @@ import torch as th def parseObs(obsConf): + # Parsing Observable Definitions if type(obsConf) == list: obs = [] for i, c in enumerate(obsConf): @@ -494,6 +495,96 @@ class ColumbusEnv(gym.Env): pygame.quit() +class ColumbusConfigDefined(ColumbusEnv): + # Allows defining Columbus Environments using dicts. + # Intended to be used in combination with cw2 configuration. + # Look into humanPlayer to see how this is supposed to be interfaced with. + + def __init__(self, observable={}, env_seed=None, entities=[], fps=30, **kw): + super().__init__( + observable=observable, fps=fps, env_seed=env_seed, **kw) + self.entities_definitions = entities + + def is_unit(self, s): + if type(s) in [int, float]: + return True + if s.replace('.', '', 1).isdigit(): + return True + num, unit = s[:-2], s[-2:] + if unit in ['px', 'em', 'rx', 'ry', 'ct']: + if num.replace('.', '', 1).isdigit(): + return True + return False + + def conv_unit(self, s, target='px', axis='x'): + assert self.is_unit(s) + if type(s) in [int, float]: + return s + if s.replace('.', '', 1).isdigit(): + if target == 'px': + return int(s) + return float(s) + num, unit = s[:-2], s[-2:] + num = float(num) + if unit == 'rx': + unit = 'px' + axis = 'x' + elif unit == 'ry': + unit = 'px' + axis = 'y' + if unit == 'em': + em = num + elif unit == 'px': + em = num / ({'x': self.width, 'y': self.height}[axis]) + elif unit == 'ct': + em = num / 100 + else: + raise Exception('Conversion not implemented') + + if target == 'em': + return em + elif target == 'px': + return int(em * ({'x': self.width, 'y': self.height}[axis])) + + def setup(self): + self.agent.pos = self.start_pos + for i, e in enumerate(self.entities_definitions): + Entity = getattr(entities, e['type']) + for i in range(e.get('num', 1) + int(self.random()*(0.99+e.get('num_rand', 0)))): + entity = Entity(self) + conf = {k: v for k, v in e.items() if str( + k) not in ['num', 'num_rand', 'type']} + + for k, v_raw in conf.items(): + if k == 'pos': + v = self.conv_unit(v_raw[0], target='em', axis='x'), self.conv_unit( + v_raw[1], target='em', axis='y') + elif k in ['width', 'height', 'radius']: + v = self.conv_unit( + v_raw, target='px', axis='y' if k == 'height' else 'x') + else: + v = v_raw + if k.endswith('_rand'): + n = k.replace('_rand', '') + cur = getattr( + entity, n) + inc = int((v+0.99)*self.random()) + setattr(entity, n, cur + inc) + elif k.endswith('_randf'): + n = k.replace('_randf', '') + cur = getattr( + entity, n) + inc = v*self.random() + setattr(entity, n, cur + inc) + else: + setattr(entity, k, v) + + self.entities.append(entity) + +### +# Custom Env Definitions + + class ColumbusTest3_1(ColumbusEnv): def __init__(self, observable=observables.CnnObservable(out_width=48, out_height=48), fps=30, aux_reward_max=1, **kw): super(ColumbusTest3_1, self).__init__( @@ -833,89 +924,6 @@ class ColumbusFootball(ColumbusEnv): self.entities.append(entities.FlyingFootballPlayer(self, ball)) -class ColumbusConfigDefined(ColumbusEnv): - def __init__(self, observable={}, env_seed=None, entities=[], fps=30, **kw): - super().__init__( - observable=observable, fps=fps, env_seed=env_seed, **kw) - self.entities_definitions = entities - - def is_unit(self, s): - if type(s) in [int, float]: - return True - if s.replace('.', '', 1).isdigit(): - return True - num, unit = s[:-2], s[-2:] - if unit in ['px', 'em', 'rx', 'ry', 'ct']: - if num.replace('.', '', 1).isdigit(): - return True - return False - - def conv_unit(self, s, target='px', axis='x'): - assert self.is_unit(s) - if type(s) in [int, float]: - return s - if s.replace('.', '', 1).isdigit(): - if target == 'px': - return int(s) - return float(s) - num, unit = s[:-2], s[-2:] - num = float(num) - if unit == 'rx': - unit = 'px' - axis = 'x' - elif unit == 'ry': - unit = 'px' - axis = 'y' - if unit == 'em': - em = num - elif unit == 'px': - em = num / ({'x': self.width, 'y': self.height}[axis]) - elif unit == 'ct': - em = num / 100 - else: - raise Exception('Conversion not implemented') - - if target == 'em': - return em - elif target == 'px': - return int(em * ({'x': self.width, 'y': self.height}[axis])) - - def setup(self): - self.agent.pos = self.start_pos - for i, e in enumerate(self.entities_definitions): - Entity = getattr(entities, e['type']) - for i in range(e.get('num', 1) + int(self.random()*(0.99+e.get('num_rand', 0)))): - entity = Entity(self) - conf = {k: v for k, v in e.items() if str( - k) not in ['num', 'num_rand', 'type']} - - for k, v_raw in conf.items(): - if k == 'pos': - v = self.conv_unit(v_raw[0], target='em', axis='x'), self.conv_unit( - v_raw[1], target='em', axis='y') - elif k in ['width', 'height', 'radius']: - v = self.conv_unit( - v_raw, target='px', axis='y' if k == 'height' else 'x') - else: - v = v_raw - if k.endswith('_rand'): - n = k.replace('_rand', '') - cur = getattr( - entity, n) - inc = int((v+0.99)*self.random()) - setattr(entity, n, cur + inc) - elif k.endswith('_randf'): - n = k.replace('_randf', '') - cur = getattr( - entity, n) - inc = v*self.random() - setattr(entity, n, cur + inc) - else: - setattr(entity, k, v) - - self.entities.append(entity) - - class ColumbusBlub(ColumbusEnv): def __init__(self, observable=observables.CompositionalObservable([observables.StateObservable(), observables.RayObservable(num_rays=6, chans=[entities.Enemy])]), env_seed=None, entities=[], fps=30, **kw): super().__init__( @@ -929,7 +937,15 @@ class ColumbusBlub(ColumbusEnv): enemy.width, enemy.height = 200, 75 self.entities.append(enemy) + ### +# Registering Envs fro Gym +register( + id='ColumbusConfigDefined-v0', + entry_point=ColumbusConfigDefined, + max_episode_steps=30*60*2, # 2 min at default (30) fps +) + # register( # id='ColumbusBlub-v0', # entry_point=ColumbusBlub, @@ -1021,12 +1037,6 @@ register( # max_episode_steps=30*60*2, # ) -register( - id='ColumbusConfigDefined-v0', - entry_point=ColumbusConfigDefined, - max_episode_steps=30*60*2, -) - register( id='ColumbusDemoEnvFootball-v0', entry_point=ColumbusDemoEnvFootball,