diff --git a/mujoco_maze/maze_env.py b/mujoco_maze/maze_env.py
index 0e8cbc0..31918f3 100644
--- a/mujoco_maze/maze_env.py
+++ b/mujoco_maze/maze_env.py
@@ -48,7 +48,8 @@ class MazeEnv(gym.Env):
         self.t = 0  # time steps
         self._observe_blocks = self._task.OBSERVE_BLOCKS
         self._put_spin_near_agent = self._task.PUT_SPIN_NEAR_AGENT
-        self._top_down_view = top_down_view
+        # Honor an explicit top_down_view=True; otherwise the task decides.
+        self._top_down_view = top_down_view or self._task.TOP_DOWN_VIEW
         self._restitution_coef = restitution_coef
 
         self._maze_structure = structure = self._task.create_maze()
@@ -248,6 +249,11 @@ class MazeEnv(gym.Env):
         self.wrapped_env = model_cls(*args, file_path=file_path, **kwargs)
         self.observation_space = self._get_obs_space()
 
+    @property
+    def has_extended_obs(self) -> bool:
+        """Whether the top-down view or block observation is enabled."""
+        return self._top_down_view or self._observe_blocks
+
     def get_ori(self) -> float:
         return self.wrapped_env.get_ori()
 
diff --git a/mujoco_maze/maze_task.py b/mujoco_maze/maze_task.py
index a3feda3..7c19e98 100644
--- a/mujoco_maze/maze_task.py
+++ b/mujoco_maze/maze_task.py
@@ -52,6 +52,7 @@ class MazeTask(ABC):
     REWARD_THRESHOLD: float
     MAZE_SIZE_SCALING: Scaling = Scaling(8.0, 4.0)
     INNER_REWARD_SCALING: float = 0.01
+    TOP_DOWN_VIEW: bool = False
     OBSERVE_BLOCKS: bool = False
     PUT_SPIN_NEAR_AGENT: bool = False
 
@@ -114,6 +115,8 @@ class DistRewardUMaze(GoalRewardUMaze, DistRewardMixIn):
 
 
 class GoalRewardPush(GoalRewardUMaze):
+    TOP_DOWN_VIEW = True
+
     def __init__(self, scale: float) -> None:
         super().__init__(scale)
         self.goals = [MazeGoal(np.array([0.0, 2.375 * scale]))]
@@ -135,6 +138,8 @@ class DistRewardPush(GoalRewardPush, DistRewardMixIn):
 
 
 class GoalRewardFall(GoalRewardUMaze):
+    TOP_DOWN_VIEW = True
+
     def __init__(self, scale: float) -> None:
         super().__init__(scale)
         self.goals = [MazeGoal(np.array([0.0, 3.375 * scale, 4.5]))]
diff --git a/tests/test_envs.py b/tests/test_envs.py
index 812aa29..76e8190 100644
--- a/tests/test_envs.py
+++ b/tests/test_envs.py
@@ -8,18 +8,22 @@ import mujoco_maze
 def test_ant_maze(maze_id):
     for i in range(2):
         env = gym.make(f"Ant{maze_id}-v{i}")
-        assert env.reset().shape == (30,)
+        s0 = env.reset()
         s, _, _, _ = env.step(env.action_space.sample())
-        assert s.shape == (30,)
+        if not env.unwrapped.has_extended_obs:
+            assert s0.shape == (30,)
+            assert s.shape == (30,)
 
 
 @pytest.mark.parametrize("maze_id", mujoco_maze.TaskRegistry.keys())
 def test_point_maze(maze_id):
     for i in range(2):
         env = gym.make(f"Point{maze_id}-v{i}")
-        assert env.reset().shape == (7,)
+        s0 = env.reset()
         s, _, _, _ = env.step(env.action_space.sample())
-        assert s.shape == (7,)
+        if not env.unwrapped.has_extended_obs:
+            assert s0.shape == (7,)
+            assert s.shape == (7,)
 
 
 @pytest.mark.parametrize("v", [0, 1])