From f94eaa5dc0af982bcb82fbdeb0b6bf37fff7b772 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Mon, 22 Aug 2022 18:53:30 +0200 Subject: [PATCH] Smashed bugs regarding StateObservable giving wrong data when not rendering --- columbus/env.py | 4 +++- columbus/humanPlayer.py | 1 - columbus/observables.py | 10 ++++++++++ 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/columbus/env.py b/columbus/env.py index 65f5778..89afb95 100644 --- a/columbus/env.py +++ b/columbus/env.py @@ -192,6 +192,8 @@ class ColumbusEnv(gym.Env): 1 and self.score > self.return_on_score info = {'score': self.score, 'reward': reward} self._rendered = False + if done: + self.reset() return observation, reward*self.reward_mult, done, info def check_collisions_for(self, entity): @@ -242,7 +244,7 @@ class ColumbusEnv(gym.Env): self.agent = entities.Agent(self) self.setup() self.entities.append(self.agent) # add it last, will be drawn on top - self.observable._entities = None + self.observable.reset() return self.observable.get_observation() def _draw_entities(self): diff --git a/columbus/humanPlayer.py b/columbus/humanPlayer.py index 4420a2c..b46e811 100644 --- a/columbus/humanPlayer.py +++ b/columbus/humanPlayer.py @@ -39,7 +39,6 @@ def chooseEnv(): def playEnv(env): - env.reset() done = False while not done: t1 = time() diff --git a/columbus/observables.py b/columbus/observables.py index 53ec1ee..eeebe8a 100644 --- a/columbus/observables.py +++ b/columbus/observables.py @@ -24,6 +24,9 @@ class Observable(): def draw(self): pass + def reset(self): + pass + class CnnObservable(Observable): def __init__(self, in_width=256, in_height=256, out_width=32, out_height=32, draw_width=128, draw_height=128, smooth_scaling=True): @@ -216,6 +219,9 @@ class StateObservable(Observable): self._timeoutEntities.append(entity) return self._entities + def reset(self): + self._entities = None + def get_observation_space(self): self.env.reset() num = len(self.entities)*2+len(self._timeoutEntities) + \ @@ -363,3 +369,7 @@ class CompositionalObservable(Observable): # self.env = env for obs in self.observables: obs._set_env(env) + + def reset(self): + for obs in self.observables: + obs.reset()