Add 2Rooms

This commit is contained in:
kngwyu 2020-06-30 01:38:02 +09:00
parent 767bd1891a
commit 142b42e34f
3 changed files with 66 additions and 12 deletions

View File

@ -2,8 +2,6 @@ import gym
from mujoco_maze.maze_task import TaskRegistry
MAZE_IDS = ["Maze", "Push", "Fall", "4Rooms"] # TODO: Block, BlockMaze
def _get_kwargs(maze_id: str) -> tuple:
return {
@ -13,8 +11,8 @@ def _get_kwargs(maze_id: str) -> tuple:
}
for maze_id in MAZE_IDS:
for i, task_cls in enumerate(TaskRegistry.REGISTRY[maze_id]):
for maze_id in TaskRegistry.keys():
for i, task_cls in enumerate(TaskRegistry.tasks(maze_id)):
gym.envs.register(
id=f"Ant{maze_id}-v{i}",
entry_point="mujoco_maze.ant_maze_env:AntMazeEnv",
@ -23,8 +21,8 @@ for maze_id in MAZE_IDS:
reward_threshold=task_cls.REWARD_THRESHOLD,
)
for maze_id in MAZE_IDS:
for i, task_cls in enumerate(TaskRegistry.REGISTRY[maze_id]):
for maze_id in TaskRegistry.keys():
for i, task_cls in enumerate(TaskRegistry.tasks(maze_id)):
gym.envs.register(
id=f"Point{maze_id}-v{i}",
entry_point="mujoco_maze.point_maze_env:PointMazeEnv",

View File

@ -142,6 +142,43 @@ class SingleGoalDenseFall(SingleGoalSparseFall):
return -self.goals[0].euc_dist(obs)
class SingleGoalSparse2Rooms(MazeTask):
REWARD_THRESHOLD: float = 0.9
def __init__(self, scale: float) -> None:
super().__init__(scale)
self.goals = [MazeGoal(np.array([0.0, 4.0 * scale]))]
def reward(self, obs: np.ndarray) -> float:
return 1.0 if self.termination(obs) else -0.0001
@staticmethod
def create_maze() -> List[List[MazeCell]]:
E, B, R = MazeCell.EMPTY, MazeCell.BLOCK, MazeCell.ROBOT
return [
[B, B, B, B, B, B, B, B],
[B, R, E, E, E, E, E, B],
[B, E, E, E, E, E, E, B],
[B, B, B, B, B, E, B, B],
[B, E, E, E, E, E, E, B],
[B, E, E, E, E, E, E, B],
[B, B, B, B, B, B, B, B],
]
class SingleGoalDense2Rooms(SingleGoalSparse2Rooms):
REWARD_THRESHOLD: float = 1000.0
def reward(self, obs: np.ndarray) -> float:
return -self.goals[0].euc_dist(obs)
class SubGoalSparse2Rooms(SingleGoalSparse2Rooms):
def __init__(self, scale: float) -> None:
super().__init__(scale)
self.goals.append(MazeGoal(np.array([5.0 * scale, 0.0 * scale]), 0.5, GREEN))
class SingleGoalSparse4Rooms(MazeTask):
REWARD_THRESHOLD: float = 0.9
@ -171,11 +208,17 @@ class SingleGoalSparse4Rooms(MazeTask):
]
class SingleGoalDense4Rooms(SingleGoalSparse4Rooms):
REWARD_THRESHOLD: float = 1000.0
def reward(self, obs: np.ndarray) -> float:
return -self.goals[0].euc_dist(obs)
class SubGoalSparse4Rooms(SingleGoalSparse4Rooms):
def __init__(self, scale: float) -> None:
super().__init__(scale)
self.goals = [
MazeGoal(np.array([6.0 * scale, 6.0 * scale])),
self.goals += [
MazeGoal(np.array([0.0 * scale, 6.0 * scale]), 0.5, GREEN),
MazeGoal(np.array([6.0 * scale, 0.0 * scale]), 0.5, GREEN),
]
@ -183,8 +226,21 @@ class SubGoalSparse4Rooms(SingleGoalSparse4Rooms):
class TaskRegistry:
REGISTRY: Dict[str, List[Type[MazeTask]]] = {
"Maze": [SingleGoalDenseUMaze, SingleGoalSparseUMaze],
"UMaze": [SingleGoalDenseUMaze, SingleGoalSparseUMaze],
"Push": [SingleGoalDensePush, SingleGoalSparsePush],
"Fall": [SingleGoalDenseFall, SingleGoalSparseFall],
"4Rooms": [SingleGoalSparse4Rooms, SubGoalSparse4Rooms],
"2Rooms": [
SingleGoalDense2Rooms,
SingleGoalSparse2Rooms,
SubGoalSparse2Rooms,
],
"4Rooms": [SingleGoalSparse4Rooms, SingleGoalDense4Rooms, SubGoalSparse4Rooms],
}
@staticmethod
def keys() -> List[str]:
return list(TaskRegistry.REGISTRY.keys())
@staticmethod
def tasks(key: str) -> List[Type[MazeTask]]:
return TaskRegistry.REGISTRY[key]

View File

@ -4,7 +4,7 @@ import pytest
import mujoco_maze
@pytest.mark.parametrize("maze_id", mujoco_maze.MAZE_IDS)
@pytest.mark.parametrize("maze_id", mujoco_maze.TaskRegistry.keys())
def test_ant_maze(maze_id):
env = gym.make("Ant{}-v0".format(maze_id))
assert env.reset().shape == (30,)
@ -12,7 +12,7 @@ def test_ant_maze(maze_id):
assert s.shape == (30,)
@pytest.mark.parametrize("maze_id", mujoco_maze.MAZE_IDS)
@pytest.mark.parametrize("maze_id", mujoco_maze.TaskRegistry.keys())
def test_point_maze(maze_id):
env = gym.make("Point{}-v0".format(maze_id))
assert env.reset().shape == (7,)