Smashed bugs regarding StateObservable giving wrong data when not

rendering
This commit is contained in:
Dominik Moritz Roth 2022-08-22 18:53:30 +02:00
parent bfbfe9bb43
commit f94eaa5dc0
3 changed files with 13 additions and 2 deletions

View File

@ -192,6 +192,8 @@ class ColumbusEnv(gym.Env):
1 and self.score > self.return_on_score 1 and self.score > self.return_on_score
info = {'score': self.score, 'reward': reward} info = {'score': self.score, 'reward': reward}
self._rendered = False self._rendered = False
if done:
self.reset()
return observation, reward*self.reward_mult, done, info return observation, reward*self.reward_mult, done, info
def check_collisions_for(self, entity): def check_collisions_for(self, entity):
@ -242,7 +244,7 @@ class ColumbusEnv(gym.Env):
self.agent = entities.Agent(self) self.agent = entities.Agent(self)
self.setup() self.setup()
self.entities.append(self.agent) # add it last, will be drawn on top self.entities.append(self.agent) # add it last, will be drawn on top
self.observable._entities = None self.observable.reset()
return self.observable.get_observation() return self.observable.get_observation()
def _draw_entities(self): def _draw_entities(self):

View File

@ -39,7 +39,6 @@ def chooseEnv():
def playEnv(env): def playEnv(env):
env.reset()
done = False done = False
while not done: while not done:
t1 = time() t1 = time()

View File

@ -24,6 +24,9 @@ class Observable():
def draw(self): def draw(self):
pass pass
def reset(self):
pass
class CnnObservable(Observable): class CnnObservable(Observable):
def __init__(self, in_width=256, in_height=256, out_width=32, out_height=32, draw_width=128, draw_height=128, smooth_scaling=True): def __init__(self, in_width=256, in_height=256, out_width=32, out_height=32, draw_width=128, draw_height=128, smooth_scaling=True):
@ -216,6 +219,9 @@ class StateObservable(Observable):
self._timeoutEntities.append(entity) self._timeoutEntities.append(entity)
return self._entities return self._entities
def reset(self):
self._entities = None
def get_observation_space(self): def get_observation_space(self):
self.env.reset() self.env.reset()
num = len(self.entities)*2+len(self._timeoutEntities) + \ num = len(self.entities)*2+len(self._timeoutEntities) + \
@ -363,3 +369,7 @@ class CompositionalObservable(Observable):
# self.env = env # self.env = env
for obs in self.observables: for obs in self.observables:
obs._set_env(env) obs._set_env(env)
def reset(self):
for obs in self.observables:
obs.reset()