diff --git a/mujoco_maze/maze_task.py b/mujoco_maze/maze_task.py
index d8b4572..4ee801b 100644
--- a/mujoco_maze/maze_task.py
+++ b/mujoco_maze/maze_task.py
@@ -145,6 +145,45 @@ class DistRewardSimpleRoom(GoalRewardSimpleRoom, DistRewardMixIn):
     pass
 
 
+class GoalRewardSquareRoom(GoalRewardUMaze):
+    """Goal-reward task played in a walled room with a 3x3 free interior."""
+
+    MAZE_SIZE_SCALING: Scaling = Scaling(4.0, 4.0, 2.0)
+
+    def __init__(self, scale: float, goal: Tuple[float, float] = (2.0, 0.0)) -> None:
+        """Set up a single goal located at ``goal`` multiplied by ``scale``."""
+        super().__init__(scale)
+        self.goals = [MazeGoal(np.array(goal) * scale)]
+
+    @staticmethod
+    def create_maze() -> List[List[MazeCell]]:
+        """Return a 5x5 grid: walls around a 3x3 free area, robot at mid-left."""
+        E, B, R = MazeCell.EMPTY, MazeCell.BLOCK, MazeCell.ROBOT
+        return [
+            [B, B, B, B, B],
+            [B, E, E, E, B],
+            [B, R, E, E, B],
+            [B, E, E, E, B],
+            [B, B, B, B, B],
+        ]
+
+
+class NoRewardSquareRoom(GoalRewardSquareRoom):
+    """Square-room task whose ``reward`` always returns 0.0."""
+
+    def __init__(self, scale: float) -> None:
+        super().__init__(scale)
+
+    # NOTE(review): the base-class ``reward`` signature is not shown in this
+    # patch; confirm that a zero-argument override matches how it is called.
+    def reward(self) -> float:
+        return 0.0
+
+
+class DistRewardSquareRoom(GoalRewardSquareRoom, DistRewardMixIn):
+    """Square-room task combined with ``DistRewardMixIn``."""
+
+
 class GoalRewardPush(GoalRewardUMaze):
     OBSERVE_BLOCKS: bool = True
 
@@ -503,6 +542,7 @@ class BanditBilliard(SubGoalBilliard):
 class TaskRegistry:
     REGISTRY: Dict[str, List[Type[MazeTask]]] = {
         "SimpleRoom": [DistRewardSimpleRoom, GoalRewardSimpleRoom],
+        "SquareRoom": [DistRewardSquareRoom, GoalRewardSquareRoom, NoRewardSquareRoom],
         "UMaze": [DistRewardUMaze, GoalRewardUMaze],
         "Push": [DistRewardPush, GoalRewardPush],
         "Fall": [DistRewardFall, GoalRewardFall],