Use a tiny heuristics in collision detection
This commit is contained in:
parent
33032dd48e
commit
e4d6338a30
@ -535,7 +535,7 @@ class MazeEnv(gym.Env):
|
||||
old_pos = self.wrapped_env.get_xy()
|
||||
inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action)
|
||||
new_pos = self.wrapped_env.get_xy()
|
||||
if self._collision.is_in(new_pos):
|
||||
if self._collision.is_in(old_pos, new_pos):
|
||||
self.wrapped_env.set_xy(old_pos)
|
||||
else:
|
||||
inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action)
|
||||
|
@ -134,11 +134,12 @@ class Collision:
|
||||
max_x = x_base + size_scaling * offset(pos, 3)
|
||||
self.objects.append((min_y, max_y, min_x, max_x))
|
||||
|
||||
def is_in(self, new_pos) -> bool:
|
||||
x, y = new_pos
|
||||
for min_y, max_y, min_x, max_x in self.objects:
|
||||
if min_x <= x <= max_x and min_y <= y <= max_y:
|
||||
return True
|
||||
def is_in(self, old_pos, new_pos) -> bool:
|
||||
# Heuristics to prevent the agent from going through the wall
|
||||
for x, y in ((old_pos + new_pos) / 2, new_pos):
|
||||
for min_y, max_y, min_x, max_x in self.objects:
|
||||
if min_x <= x <= max_x and min_y <= y <= max_y:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user