diff --git a/columbus/env.py b/columbus/env.py index 922a6fb..00f8bc7 100644 --- a/columbus/env.py +++ b/columbus/env.py @@ -431,6 +431,26 @@ class ColumbusEasierObstacles(ColumbusEnv): self.entities.append(enemy) +class ColumbusComp(ColumbusEnv): + def __init__(self, observable=observables.CompositionalObservable([observables.RayObservable(num_rays=6, chans=[entities.Enemy]), observables.StateObservable(coordsAgent=True, speedAgent=False, coordsRelativeToAgent=False, coordsRewards=True, rewardsWhitelist=None, coordsEnemys=False, enemysWhitelist=None, enemysNoBarriers=True, rewardsTimeouts=False, include_rand=True)]), hide_map=False, fps=30, env_seed=None): + super().__init__( + observable=observable, fps=fps, env_seed=env_seed) + self.draw_entities = not hide_map + self.aux_reward_max = 10 + + def setup(self): + self.agent.pos = self.start_pos + for i in range(5): + enemy = entities.CircleBarrier(self) + enemy.radius = 30 + self.random()*70 + self.entities.append(enemy) + for i in range(3): + reward = entities.TeleportingReward(self) + reward.radius = 30 + reward.reward *= 2 + self.entities.append(reward) + + class ColumbusJustState(ColumbusEnv): def __init__(self, observable=observables.StateObservable(), fps=30, num_enemies=0, num_rewards=1, env_seed=None): super(ColumbusJustState, self).__init__( @@ -573,3 +593,9 @@ register( entry_point=ColumbusFootball, max_episode_steps=30*60*2, ) + +register( + id='ColumbusComb-v0', + entry_point=ColumbusComp, + max_episode_steps=30*60*2, +) diff --git a/columbus/observables.py b/columbus/observables.py index a3135d4..37b9c12 100644 --- a/columbus/observables.py +++ b/columbus/observables.py @@ -259,3 +259,43 @@ class StateObservable(Observable): (0, y*self.env.height+ofs[0]), 1, width=0) pygame.draw.circle(self.env.screen, col, (x*self.env.width+ofs[1], 0), 1, width=0) + + +class CompositionalObservable(Observable): + def __init__(self, observables): + super().__init__() + self.observables = observables + + def get_observation_space(self): + num = 0 + low = 99999 + high = -99999 + for i, obs in enumerate(self.observables): + space = obs.get_observation_space() + num += math.prod(space.shape) + low = min(low, float(space.low[0])) + high = max(high, float(space.high[0])) + if False: + if not i: + low = space.low + high = space.high + else: + low = np.vstack((low, space.low)) + high = np.vstack((high, space.high)) + return spaces.Box(low=low, high=high, + shape=(num,), dtype=np.float32) + + def get_observation(self): + o = [float(point) + for obs in self.observables for point in obs.get_observation()] + o = np.array(o) + return o + + def draw(self): + for obs in self.observables: + obs.draw() + + def _set_env(self, env): + #self.env = env + for obs in self.observables: + obs._set_env(env)