diff --git a/columbus/env.py b/columbus/env.py
index 65f5778..89afb95 100644
--- a/columbus/env.py
+++ b/columbus/env.py
@@ -192,6 +192,8 @@ class ColumbusEnv(gym.Env):
             1 and self.score > self.return_on_score
         info = {'score': self.score, 'reward': reward}
         self._rendered = False
+        if done:
+            self.reset()
         return observation, reward*self.reward_mult, done, info
 
     def check_collisions_for(self, entity):
@@ -242,7 +244,7 @@ class ColumbusEnv(gym.Env):
         self.agent = entities.Agent(self)
         self.setup()
         self.entities.append(self.agent) # add it last, will be drawn on top
-        self.observable._entities = None
+        self.observable.reset()
         return self.observable.get_observation()
 
     def _draw_entities(self):
diff --git a/columbus/humanPlayer.py b/columbus/humanPlayer.py
index 4420a2c..b46e811 100644
--- a/columbus/humanPlayer.py
+++ b/columbus/humanPlayer.py
@@ -39,7 +39,6 @@ def chooseEnv():
 
 
 def playEnv(env):
-    env.reset()
     done = False
     while not done:
         t1 = time()
diff --git a/columbus/observables.py b/columbus/observables.py
index 53ec1ee..eeebe8a 100644
--- a/columbus/observables.py
+++ b/columbus/observables.py
@@ -24,6 +24,9 @@ class Observable():
     def draw(self):
         pass
 
+    def reset(self):
+        pass
+
 
 class CnnObservable(Observable):
     def __init__(self, in_width=256, in_height=256, out_width=32, out_height=32, draw_width=128, draw_height=128, smooth_scaling=True):
@@ -216,6 +219,9 @@ class StateObservable(Observable):
                 self._timeoutEntities.append(entity)
         return self._entities
 
+    def reset(self):
+        self._entities = None
+
     def get_observation_space(self):
         self.env.reset()
         num = len(self.entities)*2+len(self._timeoutEntities) + \
@@ -363,3 +369,7 @@ class CompositionalObservable(Observable):
         # self.env = env
         for obs in self.observables:
             obs._set_env(env)
+
+    def reset(self):
+        for obs in self.observables:
+            obs.reset()