diff --git a/mujoco_maze/maze_env.py b/mujoco_maze/maze_env.py index 44d2c56..4c9a129 100644 --- a/mujoco_maze/maze_env.py +++ b/mujoco_maze/maze_env.py @@ -565,15 +565,15 @@ def _reward_fn(maze_id: str, dense: str) -> callable: else: if maze_id in ["Maze", "Push", "BlockMaze"]: return ( - lambda obs, goal: -0.001 + lambda obs, goal: 1.0 if np.linalg.norm(obs[:2] - goal) <= 0.6 - else 1.0 + else -0.0001 ) elif maze_id == "Fall": return ( - lambda obs, goal: -0.001 + lambda obs, goal: 1.0 if np.linalg.norm(obs[:3] - goal) <= 0.6 - else 1.0 + else -0.0001 ) else: raise NotImplementedError(f"Unknown maze id: {maze_id}")