From ff4e81d4f1fc3aed1f46decf9a7fab924f1f113a Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Thu, 25 Aug 2022 13:38:59 +0200 Subject: [PATCH] Added a dummy Observable --- columbus/env.py | 4 ++++ columbus/humanPlayer.py | 1 + columbus/observables.py | 5 ++--- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/columbus/env.py b/columbus/env.py index 89afb95..00b01ea 100644 --- a/columbus/env.py +++ b/columbus/env.py @@ -35,6 +35,9 @@ def parseObs(obsConf): elif obsConf['type'] == 'CNN': conf = {k: v for k, v in obsConf.items() if k not in ['type']} return observables.CnnObservable(**conf) + elif obsConf['type'] == 'Dummy': + conf = {k: v for k, v in obsConf.items() if k not in ['type']} + return observables.Observable(**conf) else: raise Exception('Unknown Observable selected') @@ -84,6 +87,7 @@ class ColumbusEnv(gym.Env): self.void_barrier = void_is_type_barrier self.void_damage = void_damage self.torus_topology = torus_topology + self.default_collision_elasticity = 1 self.paused = False self.keypress_timeout = 0 diff --git a/columbus/humanPlayer.py b/columbus/humanPlayer.py index b46e811..0bcda40 100644 --- a/columbus/humanPlayer.py +++ b/columbus/humanPlayer.py @@ -40,6 +40,7 @@ def chooseEnv(): def playEnv(env): done = False + env.reset() while not done: t1 = time() env.render() diff --git a/columbus/observables.py b/columbus/observables.py index eeebe8a..07e3b39 100644 --- a/columbus/observables.py +++ b/columbus/observables.py @@ -13,7 +13,7 @@ class Observable(): def _set_env(self, env): self.env = env - def get_observation_space(): + def get_observation_space(self): print("[!] Using dummyObservable. Env won't output anything") return spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32) @@ -223,7 +223,7 @@ class StateObservable(Observable): self._entities = None def get_observation_space(self): - self.env.reset() + self.reset() num = len(self.entities)*2+len(self._timeoutEntities) + \ self.speedAgent*2 + self.include_rand return spaces.Box(low=0-1*self.coordsRelativeToAgent, high=1, @@ -366,7 +366,6 @@ class CompositionalObservable(Observable): obs.draw() def _set_env(self, env): - # self.env = env for obs in self.observables: obs._set_env(env)