diff --git a/mujoco_maze/maze_env.py b/mujoco_maze/maze_env.py index 3a8b15b..44d2c56 100644 --- a/mujoco_maze/maze_env.py +++ b/mujoco_maze/maze_env.py @@ -535,7 +535,7 @@ class MazeEnv(gym.Env): old_pos = self.wrapped_env.get_xy() inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action) new_pos = self.wrapped_env.get_xy() - if self._collision.is_in(new_pos): + if self._collision.is_in(old_pos, new_pos): self.wrapped_env.set_xy(old_pos) else: inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action) diff --git a/mujoco_maze/maze_env_utils.py b/mujoco_maze/maze_env_utils.py index b5f50e8..251d60f 100644 --- a/mujoco_maze/maze_env_utils.py +++ b/mujoco_maze/maze_env_utils.py @@ -134,11 +134,12 @@ class Collision: max_x = x_base + size_scaling * offset(pos, 3) self.objects.append((min_y, max_y, min_x, max_x)) - def is_in(self, new_pos) -> bool: - x, y = new_pos - for min_y, max_y, min_x, max_x in self.objects: - if min_x <= x <= max_x and min_y <= y <= max_y: - return True + def is_in(self, old_pos, new_pos) -> bool: + # Heuristics to prevent the agent from going through the wall + for x, y in ((old_pos + new_pos) / 2, new_pos): + for min_y, max_y, min_x, max_x in self.objects: + if min_x <= x <= max_x and min_y <= y <= max_y: + return True return False