Smashed bugs regarding StateObservable giving wrong data when not
rendering
This commit is contained in:
parent
bfbfe9bb43
commit
f94eaa5dc0
@ -192,6 +192,8 @@ class ColumbusEnv(gym.Env):
|
|||||||
1 and self.score > self.return_on_score
|
1 and self.score > self.return_on_score
|
||||||
info = {'score': self.score, 'reward': reward}
|
info = {'score': self.score, 'reward': reward}
|
||||||
self._rendered = False
|
self._rendered = False
|
||||||
|
if done:
|
||||||
|
self.reset()
|
||||||
return observation, reward*self.reward_mult, done, info
|
return observation, reward*self.reward_mult, done, info
|
||||||
|
|
||||||
def check_collisions_for(self, entity):
|
def check_collisions_for(self, entity):
|
||||||
@ -242,7 +244,7 @@ class ColumbusEnv(gym.Env):
|
|||||||
self.agent = entities.Agent(self)
|
self.agent = entities.Agent(self)
|
||||||
self.setup()
|
self.setup()
|
||||||
self.entities.append(self.agent) # add it last, will be drawn on top
|
self.entities.append(self.agent) # add it last, will be drawn on top
|
||||||
self.observable._entities = None
|
self.observable.reset()
|
||||||
return self.observable.get_observation()
|
return self.observable.get_observation()
|
||||||
|
|
||||||
def _draw_entities(self):
|
def _draw_entities(self):
|
||||||
|
@ -39,7 +39,6 @@ def chooseEnv():
|
|||||||
|
|
||||||
|
|
||||||
def playEnv(env):
|
def playEnv(env):
|
||||||
env.reset()
|
|
||||||
done = False
|
done = False
|
||||||
while not done:
|
while not done:
|
||||||
t1 = time()
|
t1 = time()
|
||||||
|
@ -24,6 +24,9 @@ class Observable():
|
|||||||
def draw(self):
|
def draw(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class CnnObservable(Observable):
|
class CnnObservable(Observable):
|
||||||
def __init__(self, in_width=256, in_height=256, out_width=32, out_height=32, draw_width=128, draw_height=128, smooth_scaling=True):
|
def __init__(self, in_width=256, in_height=256, out_width=32, out_height=32, draw_width=128, draw_height=128, smooth_scaling=True):
|
||||||
@ -216,6 +219,9 @@ class StateObservable(Observable):
|
|||||||
self._timeoutEntities.append(entity)
|
self._timeoutEntities.append(entity)
|
||||||
return self._entities
|
return self._entities
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self._entities = None
|
||||||
|
|
||||||
def get_observation_space(self):
|
def get_observation_space(self):
|
||||||
self.env.reset()
|
self.env.reset()
|
||||||
num = len(self.entities)*2+len(self._timeoutEntities) + \
|
num = len(self.entities)*2+len(self._timeoutEntities) + \
|
||||||
@ -363,3 +369,7 @@ class CompositionalObservable(Observable):
|
|||||||
# self.env = env
|
# self.env = env
|
||||||
for obs in self.observables:
|
for obs in self.observables:
|
||||||
obs._set_env(env)
|
obs._set_env(env)
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
for obs in self.observables:
|
||||||
|
obs.reset()
|
||||||
|
Loading…
Reference in New Issue
Block a user