Add 2Rooms
This commit is contained in:
parent
767bd1891a
commit
142b42e34f
@ -2,8 +2,6 @@ import gym
|
|||||||
|
|
||||||
from mujoco_maze.maze_task import TaskRegistry
|
from mujoco_maze.maze_task import TaskRegistry
|
||||||
|
|
||||||
MAZE_IDS = ["Maze", "Push", "Fall", "4Rooms"] # TODO: Block, BlockMaze
|
|
||||||
|
|
||||||
|
|
||||||
def _get_kwargs(maze_id: str) -> tuple:
|
def _get_kwargs(maze_id: str) -> tuple:
|
||||||
return {
|
return {
|
||||||
@ -13,8 +11,8 @@ def _get_kwargs(maze_id: str) -> tuple:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
for maze_id in MAZE_IDS:
|
for maze_id in TaskRegistry.keys():
|
||||||
for i, task_cls in enumerate(TaskRegistry.REGISTRY[maze_id]):
|
for i, task_cls in enumerate(TaskRegistry.tasks(maze_id)):
|
||||||
gym.envs.register(
|
gym.envs.register(
|
||||||
id=f"Ant{maze_id}-v{i}",
|
id=f"Ant{maze_id}-v{i}",
|
||||||
entry_point="mujoco_maze.ant_maze_env:AntMazeEnv",
|
entry_point="mujoco_maze.ant_maze_env:AntMazeEnv",
|
||||||
@ -23,8 +21,8 @@ for maze_id in MAZE_IDS:
|
|||||||
reward_threshold=task_cls.REWARD_THRESHOLD,
|
reward_threshold=task_cls.REWARD_THRESHOLD,
|
||||||
)
|
)
|
||||||
|
|
||||||
for maze_id in MAZE_IDS:
|
for maze_id in TaskRegistry.keys():
|
||||||
for i, task_cls in enumerate(TaskRegistry.REGISTRY[maze_id]):
|
for i, task_cls in enumerate(TaskRegistry.tasks(maze_id)):
|
||||||
gym.envs.register(
|
gym.envs.register(
|
||||||
id=f"Point{maze_id}-v{i}",
|
id=f"Point{maze_id}-v{i}",
|
||||||
entry_point="mujoco_maze.point_maze_env:PointMazeEnv",
|
entry_point="mujoco_maze.point_maze_env:PointMazeEnv",
|
||||||
|
@ -142,6 +142,43 @@ class SingleGoalDenseFall(SingleGoalSparseFall):
|
|||||||
return -self.goals[0].euc_dist(obs)
|
return -self.goals[0].euc_dist(obs)
|
||||||
|
|
||||||
|
|
||||||
|
class SingleGoalSparse2Rooms(MazeTask):
|
||||||
|
REWARD_THRESHOLD: float = 0.9
|
||||||
|
|
||||||
|
def __init__(self, scale: float) -> None:
|
||||||
|
super().__init__(scale)
|
||||||
|
self.goals = [MazeGoal(np.array([0.0, 4.0 * scale]))]
|
||||||
|
|
||||||
|
def reward(self, obs: np.ndarray) -> float:
|
||||||
|
return 1.0 if self.termination(obs) else -0.0001
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_maze() -> List[List[MazeCell]]:
|
||||||
|
E, B, R = MazeCell.EMPTY, MazeCell.BLOCK, MazeCell.ROBOT
|
||||||
|
return [
|
||||||
|
[B, B, B, B, B, B, B, B],
|
||||||
|
[B, R, E, E, E, E, E, B],
|
||||||
|
[B, E, E, E, E, E, E, B],
|
||||||
|
[B, B, B, B, B, E, B, B],
|
||||||
|
[B, E, E, E, E, E, E, B],
|
||||||
|
[B, E, E, E, E, E, E, B],
|
||||||
|
[B, B, B, B, B, B, B, B],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class SingleGoalDense2Rooms(SingleGoalSparse2Rooms):
|
||||||
|
REWARD_THRESHOLD: float = 1000.0
|
||||||
|
|
||||||
|
def reward(self, obs: np.ndarray) -> float:
|
||||||
|
return -self.goals[0].euc_dist(obs)
|
||||||
|
|
||||||
|
|
||||||
|
class SubGoalSparse2Rooms(SingleGoalSparse2Rooms):
|
||||||
|
def __init__(self, scale: float) -> None:
|
||||||
|
super().__init__(scale)
|
||||||
|
self.goals.append(MazeGoal(np.array([5.0 * scale, 0.0 * scale]), 0.5, GREEN))
|
||||||
|
|
||||||
|
|
||||||
class SingleGoalSparse4Rooms(MazeTask):
|
class SingleGoalSparse4Rooms(MazeTask):
|
||||||
REWARD_THRESHOLD: float = 0.9
|
REWARD_THRESHOLD: float = 0.9
|
||||||
|
|
||||||
@ -171,11 +208,17 @@ class SingleGoalSparse4Rooms(MazeTask):
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class SingleGoalDense4Rooms(SingleGoalSparse4Rooms):
|
||||||
|
REWARD_THRESHOLD: float = 1000.0
|
||||||
|
|
||||||
|
def reward(self, obs: np.ndarray) -> float:
|
||||||
|
return -self.goals[0].euc_dist(obs)
|
||||||
|
|
||||||
|
|
||||||
class SubGoalSparse4Rooms(SingleGoalSparse4Rooms):
|
class SubGoalSparse4Rooms(SingleGoalSparse4Rooms):
|
||||||
def __init__(self, scale: float) -> None:
|
def __init__(self, scale: float) -> None:
|
||||||
super().__init__(scale)
|
super().__init__(scale)
|
||||||
self.goals = [
|
self.goals += [
|
||||||
MazeGoal(np.array([6.0 * scale, 6.0 * scale])),
|
|
||||||
MazeGoal(np.array([0.0 * scale, 6.0 * scale]), 0.5, GREEN),
|
MazeGoal(np.array([0.0 * scale, 6.0 * scale]), 0.5, GREEN),
|
||||||
MazeGoal(np.array([6.0 * scale, 0.0 * scale]), 0.5, GREEN),
|
MazeGoal(np.array([6.0 * scale, 0.0 * scale]), 0.5, GREEN),
|
||||||
]
|
]
|
||||||
@ -183,8 +226,21 @@ class SubGoalSparse4Rooms(SingleGoalSparse4Rooms):
|
|||||||
|
|
||||||
class TaskRegistry:
|
class TaskRegistry:
|
||||||
REGISTRY: Dict[str, List[Type[MazeTask]]] = {
|
REGISTRY: Dict[str, List[Type[MazeTask]]] = {
|
||||||
"Maze": [SingleGoalDenseUMaze, SingleGoalSparseUMaze],
|
"UMaze": [SingleGoalDenseUMaze, SingleGoalSparseUMaze],
|
||||||
"Push": [SingleGoalDensePush, SingleGoalSparsePush],
|
"Push": [SingleGoalDensePush, SingleGoalSparsePush],
|
||||||
"Fall": [SingleGoalDenseFall, SingleGoalSparseFall],
|
"Fall": [SingleGoalDenseFall, SingleGoalSparseFall],
|
||||||
"4Rooms": [SingleGoalSparse4Rooms, SubGoalSparse4Rooms],
|
"2Rooms": [
|
||||||
|
SingleGoalDense2Rooms,
|
||||||
|
SingleGoalSparse2Rooms,
|
||||||
|
SubGoalSparse2Rooms,
|
||||||
|
],
|
||||||
|
"4Rooms": [SingleGoalSparse4Rooms, SingleGoalDense4Rooms, SubGoalSparse4Rooms],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def keys() -> List[str]:
|
||||||
|
return list(TaskRegistry.REGISTRY.keys())
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def tasks(key: str) -> List[Type[MazeTask]]:
|
||||||
|
return TaskRegistry.REGISTRY[key]
|
||||||
|
@ -4,7 +4,7 @@ import pytest
|
|||||||
import mujoco_maze
|
import mujoco_maze
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("maze_id", mujoco_maze.MAZE_IDS)
|
@pytest.mark.parametrize("maze_id", mujoco_maze.TaskRegistry.keys())
|
||||||
def test_ant_maze(maze_id):
|
def test_ant_maze(maze_id):
|
||||||
env = gym.make("Ant{}-v0".format(maze_id))
|
env = gym.make("Ant{}-v0".format(maze_id))
|
||||||
assert env.reset().shape == (30,)
|
assert env.reset().shape == (30,)
|
||||||
@ -12,7 +12,7 @@ def test_ant_maze(maze_id):
|
|||||||
assert s.shape == (30,)
|
assert s.shape == (30,)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("maze_id", mujoco_maze.MAZE_IDS)
|
@pytest.mark.parametrize("maze_id", mujoco_maze.TaskRegistry.keys())
|
||||||
def test_point_maze(maze_id):
|
def test_point_maze(maze_id):
|
||||||
env = gym.make("Point{}-v0".format(maze_id))
|
env = gym.make("Point{}-v0".format(maze_id))
|
||||||
assert env.reset().shape == (7,)
|
assert env.reset().shape == (7,)
|
||||||
|
Loading…
Reference in New Issue
Block a user