diff --git a/mujoco_maze/maze_env.py b/mujoco_maze/maze_env.py index 52df046..d9fe7e8 100644 --- a/mujoco_maze/maze_env.py +++ b/mujoco_maze/maze_env.py @@ -32,10 +32,11 @@ class MazeEnv(gym.Env): maze_size_scaling: float = 4.0, inner_reward_scaling: float = 1.0, restitution_coef: float = 0.8, + task_kwargs: dict = {}, *args, **kwargs, ) -> None: - self._task = maze_task(maze_size_scaling) + self._task = maze_task(maze_size_scaling, **task_kwargs) xml_path = os.path.join(MODEL_DIR, model_cls.FILE) tree = ET.parse(xml_path) diff --git a/mujoco_maze/maze_task.py b/mujoco_maze/maze_task.py index 74ed313..a3feda3 100644 --- a/mujoco_maze/maze_task.py +++ b/mujoco_maze/maze_task.py @@ -2,7 +2,7 @@ """ from abc import ABC, abstractmethod -from typing import Dict, List, NamedTuple, Type +from typing import Dict, List, NamedTuple, Tuple, Type import numpy as np @@ -237,6 +237,44 @@ class SubGoal4Rooms(GoalReward4Rooms): ] +class GoalRewardTRoom(MazeTask): + REWARD_THRESHOLD: float = 0.9 + MAZE_SIZE_SCALING: Scaling = Scaling(4.0, 4.0) + + def __init__( + self, + scale: float, + goals: List[Tuple[float, float]] = [(2.0, -4.0)], + ) -> None: + super().__init__(scale) + self.goals = [] + for x, y in goals: + self.goals.append(MazeGoal(np.array([x * scale, y * scale]))) + + def reward(self, obs: np.ndarray) -> float: + for goal in self.goals: + if goal.neighbor(obs): + return goal.reward_scale + return -0.0001 + + @staticmethod + def create_maze() -> List[List[MazeCell]]: + E, B, R = MazeCell.EMPTY, MazeCell.BLOCK, MazeCell.ROBOT + return [ + [B, B, B, B, B, B, B], + [B, E, E, E, E, E, B], + [B, E, E, B, E, E, B], + [B, E, E, B, E, E, B], + [B, E, B, B, B, E, B], + [B, E, E, R, E, E, B], + [B, B, B, B, B, B, B], + ] + + +class DistRewardTRoom(GoalRewardTRoom, DistRewardMixIn): + pass + + class TaskRegistry: REGISTRY: Dict[str, List[Type[MazeTask]]] = { "UMaze": [DistRewardUMaze, GoalRewardUMaze], @@ -244,6 +282,7 @@ class TaskRegistry: "Fall": [DistRewardFall, GoalRewardFall], "2Rooms": [DistReward2Rooms, GoalReward2Rooms, SubGoal2Rooms], "4Rooms": [DistReward4Rooms, GoalReward4Rooms, SubGoal4Rooms], + "TRoom": [DistRewardTRoom, GoalRewardTRoom], } @staticmethod diff --git a/tests/test_envs.py b/tests/test_envs.py index 469ab1a..45346c4 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -20,3 +20,11 @@ def test_point_maze(maze_id): assert env.reset().shape == (7,) s, _, _, _ = env.step(env.action_space.sample()) assert s.shape == (7,) + + +@pytest.mark.parametrize("v", [0, 1]) +def test_maze_args(v): + env = gym.make(f"PointTRoom-v{v}", task_kwargs={"goals": [(-2.0, 4.0)]}) + assert env.reset().shape == (7,) + s, _, _, _ = env.step(env.action_space.sample()) + assert s.shape == (7,)