From 142b42e34f5a8fa5595b514b5248cf9919803758 Mon Sep 17 00:00:00 2001 From: kngwyu Date: Tue, 30 Jun 2020 01:38:02 +0900 Subject: [PATCH] Add 2Rooms --- mujoco_maze/__init__.py | 10 +++---- mujoco_maze/maze_task.py | 64 +++++++++++++++++++++++++++++++++++++--- tests/test_envs.py | 4 +-- 3 files changed, 66 insertions(+), 12 deletions(-) diff --git a/mujoco_maze/__init__.py b/mujoco_maze/__init__.py index d390455..1ff8ecd 100644 --- a/mujoco_maze/__init__.py +++ b/mujoco_maze/__init__.py @@ -2,8 +2,6 @@ import gym from mujoco_maze.maze_task import TaskRegistry -MAZE_IDS = ["Maze", "Push", "Fall", "4Rooms"] # TODO: Block, BlockMaze - def _get_kwargs(maze_id: str) -> tuple: return { @@ -13,8 +11,8 @@ def _get_kwargs(maze_id: str) -> tuple: } -for maze_id in MAZE_IDS: - for i, task_cls in enumerate(TaskRegistry.REGISTRY[maze_id]): +for maze_id in TaskRegistry.keys(): + for i, task_cls in enumerate(TaskRegistry.tasks(maze_id)): gym.envs.register( id=f"Ant{maze_id}-v{i}", entry_point="mujoco_maze.ant_maze_env:AntMazeEnv", @@ -23,8 +21,8 @@ for maze_id in MAZE_IDS: reward_threshold=task_cls.REWARD_THRESHOLD, ) -for maze_id in MAZE_IDS: - for i, task_cls in enumerate(TaskRegistry.REGISTRY[maze_id]): +for maze_id in TaskRegistry.keys(): + for i, task_cls in enumerate(TaskRegistry.tasks(maze_id)): gym.envs.register( id=f"Point{maze_id}-v{i}", entry_point="mujoco_maze.point_maze_env:PointMazeEnv", diff --git a/mujoco_maze/maze_task.py b/mujoco_maze/maze_task.py index b7e4326..d70dc7d 100644 --- a/mujoco_maze/maze_task.py +++ b/mujoco_maze/maze_task.py @@ -142,6 +142,43 @@ class SingleGoalDenseFall(SingleGoalSparseFall): return -self.goals[0].euc_dist(obs) +class SingleGoalSparse2Rooms(MazeTask): + REWARD_THRESHOLD: float = 0.9 + + def __init__(self, scale: float) -> None: + super().__init__(scale) + self.goals = [MazeGoal(np.array([0.0, 4.0 * scale]))] + + def reward(self, obs: np.ndarray) -> float: + return 1.0 if self.termination(obs) else -0.0001 + + @staticmethod + def create_maze() -> List[List[MazeCell]]: + E, B, R = MazeCell.EMPTY, MazeCell.BLOCK, MazeCell.ROBOT + return [ + [B, B, B, B, B, B, B, B], + [B, R, E, E, E, E, E, B], + [B, E, E, E, E, E, E, B], + [B, B, B, B, B, E, B, B], + [B, E, E, E, E, E, E, B], + [B, E, E, E, E, E, E, B], + [B, B, B, B, B, B, B, B], + ] + + +class SingleGoalDense2Rooms(SingleGoalSparse2Rooms): + REWARD_THRESHOLD: float = 1000.0 + + def reward(self, obs: np.ndarray) -> float: + return -self.goals[0].euc_dist(obs) + + +class SubGoalSparse2Rooms(SingleGoalSparse2Rooms): + def __init__(self, scale: float) -> None: + super().__init__(scale) + self.goals.append(MazeGoal(np.array([5.0 * scale, 0.0 * scale]), 0.5, GREEN)) + + class SingleGoalSparse4Rooms(MazeTask): REWARD_THRESHOLD: float = 0.9 @@ -171,11 +208,17 @@ class SingleGoalSparse4Rooms(MazeTask): ] +class SingleGoalDense4Rooms(SingleGoalSparse4Rooms): + REWARD_THRESHOLD: float = 1000.0 + + def reward(self, obs: np.ndarray) -> float: + return -self.goals[0].euc_dist(obs) + + class SubGoalSparse4Rooms(SingleGoalSparse4Rooms): def __init__(self, scale: float) -> None: super().__init__(scale) - self.goals = [ - MazeGoal(np.array([6.0 * scale, 6.0 * scale])), + self.goals += [ MazeGoal(np.array([0.0 * scale, 6.0 * scale]), 0.5, GREEN), MazeGoal(np.array([6.0 * scale, 0.0 * scale]), 0.5, GREEN), ] @@ -183,8 +226,21 @@ class SubGoalSparse4Rooms(SingleGoalSparse4Rooms): class TaskRegistry: REGISTRY: Dict[str, List[Type[MazeTask]]] = { - "Maze": [SingleGoalDenseUMaze, SingleGoalSparseUMaze], + "UMaze": [SingleGoalDenseUMaze, SingleGoalSparseUMaze], "Push": [SingleGoalDensePush, SingleGoalSparsePush], "Fall": [SingleGoalDenseFall, SingleGoalSparseFall], - "4Rooms": [SingleGoalSparse4Rooms, SubGoalSparse4Rooms], + "2Rooms": [ + SingleGoalDense2Rooms, + SingleGoalSparse2Rooms, + SubGoalSparse2Rooms, + ], + "4Rooms": [SingleGoalSparse4Rooms, SingleGoalDense4Rooms, SubGoalSparse4Rooms], } + + @staticmethod + def keys() -> List[str]: + return list(TaskRegistry.REGISTRY.keys()) + + @staticmethod + def tasks(key: str) -> List[Type[MazeTask]]: + return TaskRegistry.REGISTRY[key] diff --git a/tests/test_envs.py b/tests/test_envs.py index d9a1df8..060be3a 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -4,7 +4,7 @@ import pytest import mujoco_maze -@pytest.mark.parametrize("maze_id", mujoco_maze.MAZE_IDS) +@pytest.mark.parametrize("maze_id", mujoco_maze.TaskRegistry.keys()) def test_ant_maze(maze_id): env = gym.make("Ant{}-v0".format(maze_id)) assert env.reset().shape == (30,) @@ -12,7 +12,7 @@ def test_ant_maze(maze_id): assert s.shape == (30,) -@pytest.mark.parametrize("maze_id", mujoco_maze.MAZE_IDS) +@pytest.mark.parametrize("maze_id", mujoco_maze.TaskRegistry.keys()) def test_point_maze(maze_id): env = gym.make("Point{}-v0".format(maze_id)) assert env.reset().shape == (7,)