Add 2Rooms
This commit is contained in:
parent
767bd1891a
commit
142b42e34f
@ -2,8 +2,6 @@ import gym
|
||||
|
||||
from mujoco_maze.maze_task import TaskRegistry
|
||||
|
||||
MAZE_IDS = ["Maze", "Push", "Fall", "4Rooms"] # TODO: Block, BlockMaze
|
||||
|
||||
|
||||
def _get_kwargs(maze_id: str) -> tuple:
|
||||
return {
|
||||
@ -13,8 +11,8 @@ def _get_kwargs(maze_id: str) -> tuple:
|
||||
}
|
||||
|
||||
|
||||
for maze_id in MAZE_IDS:
|
||||
for i, task_cls in enumerate(TaskRegistry.REGISTRY[maze_id]):
|
||||
for maze_id in TaskRegistry.keys():
|
||||
for i, task_cls in enumerate(TaskRegistry.tasks(maze_id)):
|
||||
gym.envs.register(
|
||||
id=f"Ant{maze_id}-v{i}",
|
||||
entry_point="mujoco_maze.ant_maze_env:AntMazeEnv",
|
||||
@ -23,8 +21,8 @@ for maze_id in MAZE_IDS:
|
||||
reward_threshold=task_cls.REWARD_THRESHOLD,
|
||||
)
|
||||
|
||||
for maze_id in MAZE_IDS:
|
||||
for i, task_cls in enumerate(TaskRegistry.REGISTRY[maze_id]):
|
||||
for maze_id in TaskRegistry.keys():
|
||||
for i, task_cls in enumerate(TaskRegistry.tasks(maze_id)):
|
||||
gym.envs.register(
|
||||
id=f"Point{maze_id}-v{i}",
|
||||
entry_point="mujoco_maze.point_maze_env:PointMazeEnv",
|
||||
|
@ -142,6 +142,43 @@ class SingleGoalDenseFall(SingleGoalSparseFall):
|
||||
return -self.goals[0].euc_dist(obs)
|
||||
|
||||
|
||||
class SingleGoalSparse2Rooms(MazeTask):
|
||||
REWARD_THRESHOLD: float = 0.9
|
||||
|
||||
def __init__(self, scale: float) -> None:
|
||||
super().__init__(scale)
|
||||
self.goals = [MazeGoal(np.array([0.0, 4.0 * scale]))]
|
||||
|
||||
def reward(self, obs: np.ndarray) -> float:
|
||||
return 1.0 if self.termination(obs) else -0.0001
|
||||
|
||||
@staticmethod
|
||||
def create_maze() -> List[List[MazeCell]]:
|
||||
E, B, R = MazeCell.EMPTY, MazeCell.BLOCK, MazeCell.ROBOT
|
||||
return [
|
||||
[B, B, B, B, B, B, B, B],
|
||||
[B, R, E, E, E, E, E, B],
|
||||
[B, E, E, E, E, E, E, B],
|
||||
[B, B, B, B, B, E, B, B],
|
||||
[B, E, E, E, E, E, E, B],
|
||||
[B, E, E, E, E, E, E, B],
|
||||
[B, B, B, B, B, B, B, B],
|
||||
]
|
||||
|
||||
|
||||
class SingleGoalDense2Rooms(SingleGoalSparse2Rooms):
|
||||
REWARD_THRESHOLD: float = 1000.0
|
||||
|
||||
def reward(self, obs: np.ndarray) -> float:
|
||||
return -self.goals[0].euc_dist(obs)
|
||||
|
||||
|
||||
class SubGoalSparse2Rooms(SingleGoalSparse2Rooms):
|
||||
def __init__(self, scale: float) -> None:
|
||||
super().__init__(scale)
|
||||
self.goals.append(MazeGoal(np.array([5.0 * scale, 0.0 * scale]), 0.5, GREEN))
|
||||
|
||||
|
||||
class SingleGoalSparse4Rooms(MazeTask):
|
||||
REWARD_THRESHOLD: float = 0.9
|
||||
|
||||
@ -171,11 +208,17 @@ class SingleGoalSparse4Rooms(MazeTask):
|
||||
]
|
||||
|
||||
|
||||
class SingleGoalDense4Rooms(SingleGoalSparse4Rooms):
|
||||
REWARD_THRESHOLD: float = 1000.0
|
||||
|
||||
def reward(self, obs: np.ndarray) -> float:
|
||||
return -self.goals[0].euc_dist(obs)
|
||||
|
||||
|
||||
class SubGoalSparse4Rooms(SingleGoalSparse4Rooms):
|
||||
def __init__(self, scale: float) -> None:
|
||||
super().__init__(scale)
|
||||
self.goals = [
|
||||
MazeGoal(np.array([6.0 * scale, 6.0 * scale])),
|
||||
self.goals += [
|
||||
MazeGoal(np.array([0.0 * scale, 6.0 * scale]), 0.5, GREEN),
|
||||
MazeGoal(np.array([6.0 * scale, 0.0 * scale]), 0.5, GREEN),
|
||||
]
|
||||
@ -183,8 +226,21 @@ class SubGoalSparse4Rooms(SingleGoalSparse4Rooms):
|
||||
|
||||
class TaskRegistry:
|
||||
REGISTRY: Dict[str, List[Type[MazeTask]]] = {
|
||||
"Maze": [SingleGoalDenseUMaze, SingleGoalSparseUMaze],
|
||||
"UMaze": [SingleGoalDenseUMaze, SingleGoalSparseUMaze],
|
||||
"Push": [SingleGoalDensePush, SingleGoalSparsePush],
|
||||
"Fall": [SingleGoalDenseFall, SingleGoalSparseFall],
|
||||
"4Rooms": [SingleGoalSparse4Rooms, SubGoalSparse4Rooms],
|
||||
"2Rooms": [
|
||||
SingleGoalDense2Rooms,
|
||||
SingleGoalSparse2Rooms,
|
||||
SubGoalSparse2Rooms,
|
||||
],
|
||||
"4Rooms": [SingleGoalSparse4Rooms, SingleGoalDense4Rooms, SubGoalSparse4Rooms],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def keys() -> List[str]:
|
||||
return list(TaskRegistry.REGISTRY.keys())
|
||||
|
||||
@staticmethod
|
||||
def tasks(key: str) -> List[Type[MazeTask]]:
|
||||
return TaskRegistry.REGISTRY[key]
|
||||
|
@ -4,7 +4,7 @@ import pytest
|
||||
import mujoco_maze
|
||||
|
||||
|
||||
@pytest.mark.parametrize("maze_id", mujoco_maze.MAZE_IDS)
|
||||
@pytest.mark.parametrize("maze_id", mujoco_maze.TaskRegistry.keys())
|
||||
def test_ant_maze(maze_id):
|
||||
env = gym.make("Ant{}-v0".format(maze_id))
|
||||
assert env.reset().shape == (30,)
|
||||
@ -12,7 +12,7 @@ def test_ant_maze(maze_id):
|
||||
assert s.shape == (30,)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("maze_id", mujoco_maze.MAZE_IDS)
|
||||
@pytest.mark.parametrize("maze_id", mujoco_maze.TaskRegistry.keys())
|
||||
def test_point_maze(maze_id):
|
||||
env = gym.make("Point{}-v0".format(maze_id))
|
||||
assert env.reset().shape == (7,)
|
||||
|
Loading…
Reference in New Issue
Block a user