def parseObs(obsConf):
    """Build an observable from a configuration object.

    Args:
        obsConf: Either a dict whose ``'type'`` entry selects the observable
            (``'State'``, ``'Compass'``, ``'RayCast'`` or ``'CNN'``) and whose
            remaining entries are forwarded to that observable, or a list of
            such dicts (lists may be nested).

    Returns:
        The constructed observable. A list with exactly one element yields
        that element's observable directly; any other list is wrapped in
        ``observables.CompositionalObservable``.

    Raises:
        Exception: If ``obsConf['type']`` is not one of the known types.
    """
    if isinstance(obsConf, list):
        parsed = [parseObs(sub_conf) for sub_conf in obsConf]
        if len(parsed) == 1:
            return parsed[0]
        return observables.CompositionalObservable(parsed)

    obs_type = obsConf['type']
    # Everything except the selector key is forwarded to the observable.
    conf = {k: v for k, v in obsConf.items() if k != 'type'}

    if obs_type == 'State':
        # NOTE(review): unlike the other branches the config dict is passed
        # as a single positional argument rather than unpacked with **conf.
        # Left unchanged because StateObservable's signature is not visible
        # here - confirm this is intended.
        return observables.StateObservable(conf)
    if obs_type == 'Compass':
        return observables.CompassObservable(**conf)
    if obs_type == 'RayCast':
        # Channel names are resolved to attributes of the entities module.
        chans = [getattr(entities, chan) for chan in conf.pop('chans', [])]
        return observables.RayObservable(chans=chans, **conf)
    if obs_type == 'CNN':
        return observables.CnnObservable(**conf)
    raise Exception('Unknown Observable selected')


class ColumbusCompassWithBarriers(ColumbusEnv):
    """Environment with circular barriers, optional chasers and one reward.

    By default the environment is observed through a
    ``observables.CompassObservable`` created with ``coordsRewards=True``.
    """

    def __init__(self, observable=None, fps=30, env_seed=3.141, num_enemys=0, num_barriers=3, aux_reward_max=10, **kw):
        """Create the environment.

        Args:
            observable: Observable to use. ``None`` (the default) creates a
                fresh ``CompassObservable(coordsRewards=True)`` for this
                instance, so that separate environments never share one
                observable object and its cached state.
            fps: Forwarded to ``ColumbusEnv``.
            env_seed: Forwarded to ``ColumbusEnv``.
            num_enemys: Number of ``FlyingChaser`` entities added in setup.
            num_barriers: Number of ``CircleBarrier`` entities added in setup.
            aux_reward_max: Forwarded to ``ColumbusEnv``.
            **kw: Further keyword arguments forwarded to ``ColumbusEnv``.
        """
        if observable is None:
            observable = observables.CompassObservable(coordsRewards=True)
        super().__init__(
            observable=observable, fps=fps, env_seed=env_seed, aux_reward_max=aux_reward_max, **kw)
        self.start_pos = (0.5, 0.5)
        self.num_barriers = num_barriers
        self.num_enemys = num_enemys

    def setup(self):
        """Place the agent and populate the entity list."""
        self.agent.pos = self.start_pos
        for _ in range(self.num_barriers):
            barrier = entities.CircleBarrier(self)
            barrier.radius = self.random()*25+75
            self.entities.append(barrier)
        for _ in range(self.num_enemys):
            chaser = entities.FlyingChaser(self)
            chaser.chase_acc = 0.55
            self.entities.append(chaser)
        reward = entities.TeleportingReward(self)
        reward.radius = 30
        self.entities.append(reward)


register(
    id='ColumbusCompassWithBarriers-v0',
    entry_point=ColumbusCompassWithBarriers,
    max_episode_steps=30*60*2,
)
class CompassObservable(Observable):
    """Observable reporting a bounded direction towards each tracked entity.

    For every tracked entity two values are emitted: ``tanh`` of the x and y
    offset from the agent, each divided by twice the Euclidean distance
    between agent and entity.
    """

    def __init__(self, coordsRewards=True, rewardsWhitelist=None, coordsEnemys=False, enemysWhitelist=None, enemysNoBarriers=True):
        """Configure which entities are tracked.

        Args:
            coordsRewards: Track ``entities.Reward`` instances.
            rewardsWhitelist: Candidates searched for rewards; when falsy,
                ``self.env.entities`` is searched instead.
            coordsEnemys: Track ``entities.Enemy`` instances.
            enemysWhitelist: Candidates searched for enemies; when falsy,
                ``self.env.entities`` is searched instead.
            enemysNoBarriers: Skip enemies that are also ``entities.Barrier``
                instances.
        """
        super().__init__()
        self._entities = None  # lazily built cache, see ``entities``
        self._timeoutEntities = []
        self.coordRewards = coordsRewards
        self.rewardsWhitelist = rewardsWhitelist
        self.coordsEnemys = coordsEnemys
        self.enemysWhitelist = enemysWhitelist
        self.enemysNoBarriers = enemysNoBarriers

    @property
    def entities(self):
        """Return the tracked entities, building the list on first access.

        NOTE(review): the cached list is never invalidated. If the
        environment replaces its entity objects (e.g. on reset) this will
        keep referring to the old ones - confirm against the env code.
        """
        if self._entities is not None:
            return self._entities
        rewardsWhitelist = self.rewardsWhitelist or self.env.entities
        enemysWhitelist = self.enemysWhitelist or self.env.entities
        tracked = []
        if self.coordRewards:
            tracked.extend(
                entity for entity in rewardsWhitelist
                if isinstance(entity, entities.Reward))
        if self.coordsEnemys:
            for entity in enemysWhitelist:
                if not isinstance(entity, entities.Enemy):
                    continue
                if self.enemysNoBarriers and isinstance(entity, entities.Barrier):
                    continue
                tracked.append(entity)
        self._entities = tracked
        return self._entities

    def get_observation_space(self):
        """Return a ``Box`` in [-1, 1] with two values per tracked entity.

        Resets the environment first, before the entity list is read.
        """
        self.env.reset()
        num = len(self.entities)*2
        return spaces.Box(low=-1, high=1,
                          shape=(num,), dtype=np.float32)

    def get_observation(self):
        """Compute the observation and remember it for ``draw``.

        Returns:
            A float32 array ``[x0, y0, x1, y1, ...]`` with one pair per
            tracked entity. An entity at zero distance from the agent
            contributes ``(0.0, 0.0)`` because it has no direction.
        """
        agent_pos = self.env.agent.pos
        obs = []
        for entity in self.entities:
            dx = entity.pos[0] - agent_pos[0]
            dy = entity.pos[1] - agent_pos[1]
            scale = math.sqrt(dx**2 + dy**2)*2
            if scale == 0:
                # Agent is on top of the entity: avoid dividing by zero.
                obs.extend((0.0, 0.0))
                continue
            obs.append(math.tanh(dx/scale))
            obs.append(math.tanh(dy/scale))

        self.obs = obs
        # Match the dtype declared by get_observation_space.
        return np.array(obs, dtype=np.float32)

    def draw(self):
        """Draw agent and entity direction markers on the screen edges."""
        half_height = self.env.height/2
        half_width = self.env.width/2
        screen = self.env.screen
        # Reference markers at the middle of the left and top edges.
        pygame.draw.circle(screen, self.env.agent.col,
                           (0, half_height), 3, width=0)
        pygame.draw.circle(screen, self.env.agent.col,
                           (half_width, 0), 3, width=0)
        obs = getattr(self, 'obs', None)
        if obs is None:
            # No observation has been computed yet, nothing more to draw.
            return
        for entity, x, y in zip(self.entities, obs[0::2], obs[1::2]):
            col = entity.col
            pygame.draw.circle(screen, col,
                               (0, y*self.env.height+half_height), 1, width=0)
            pygame.draw.circle(screen, col,
                               (x*self.env.width+half_width, 0), 1, width=0)