Invalidate value-map upon reward-change

This commit is contained in:
Dominik Moritz Roth 2022-11-01 16:14:28 +01:00
parent 0fe5c35dda
commit 555a4780e3

View File

@ -368,6 +368,9 @@ class TeleportingReward(OnceReward):
self.env.check_collisions_for(self) self.env.check_collisions_for(self)
def on_collected(self): def on_collected(self):
# Force rerender of value func (even in static envs)
self.env._invalidate_value_map()
self.env.new_abs_reward += self.reward self.env.new_abs_reward += self.reward
self.pos = (self.env.random(), self.env.random()) self.pos = (self.env.random(), self.env.random())
self.env.check_collisions_for(self) self.env.check_collisions_for(self)
@ -382,6 +385,9 @@ class LoopReward(OnceReward):
self.barrier_physics = False self.barrier_physics = False
def jump_to_state(self): def jump_to_state(self):
# Force rerender of value func (even in static envs)
self.env._invalidate_value_map()
pos_vec = [v for v in self.loop[self.state]] pos_vec = [v for v in self.loop[self.state]]
if len(pos_vec) == 4: if len(pos_vec) == 4:
pos_vec = pos_vec[0] + pos_vec[2] * \ pos_vec = pos_vec[0] + pos_vec[2] * \
@ -422,6 +428,9 @@ class TimeoutReward(OnceReward):
def on_collected(self): def on_collected(self):
if self.avaible: if self.avaible:
# Force rerender of value func (even in static envs)
self.env._invalidate_value_map()
self.env.new_abs_reward += self.reward self.env.new_abs_reward += self.reward
self.set_avaible(False) self.set_avaible(False)
self.env.timers.append((self.timeout, self.set_avaible, True)) self.env.timers.append((self.timeout, self.set_avaible, True))
@ -468,6 +477,9 @@ class TeleportingGoal(Goal):
self.env.check_collisions_for(self) self.env.check_collisions_for(self)
def on_collected(self): def on_collected(self):
# Force rerender of value func (even in static envs)
self.env._invalidate_value_map()
self.env.new_abs_reward += self.reward self.env.new_abs_reward += self.reward
self.pos = (self.env.random(), self.env.random()) self.pos = (self.env.random(), self.env.random())
self.env.check_collisions_for(self) self.env.check_collisions_for(self)