diff --git a/mujoco_maze/maze_task.py b/mujoco_maze/maze_task.py index 52aed2d..c8b613e 100644 --- a/mujoco_maze/maze_task.py +++ b/mujoco_maze/maze_task.py @@ -230,6 +230,33 @@ class NoRewardMultiPush(GoalRewardMultiPush): return 0.0 +class GoalRewardMultiPushSmall(GoalRewardMultiPush): + def __init__(self, scale: float, goal: Tuple[float, float] = (1.0, -1.0)) -> None: + super().__init__(scale, goal) + + @staticmethod + def create_maze() -> List[List[MazeCell]]: + E, B, R, M = MazeCell.EMPTY, MazeCell.BLOCK, MazeCell.ROBOT, MazeCell.XY_BLOCK + return [ + [B, B, B, B, B, B], + [B, B, E, B, B, B], + [B, E, M, E, B, B], + [B, B, R, M, E, B], + [B, E, M, E, B, B], + [B, B, E, B, B, B], + [B, B, B, B, B, B], + ] + + +class DistRewardMultiPushSmall(GoalRewardMultiPushSmall, DistRewardMixIn): + pass + + +class NoRewardMultiPushSmall(GoalRewardMultiPushSmall): + def reward(self, _obs: np.ndarray) -> float: + return 0.0 + + class GoalRewardPushMaze(GoalRewardUMaze): OBSERVE_BLOCKS: bool = True MAZE_SIZE_SCALING: Scaling = Scaling(ant=2.0, point=6.0, swimmer=None) @@ -742,6 +769,11 @@ class TaskRegistry: "UMaze": [DistRewardUMaze, GoalRewardUMaze], "Push": [DistRewardPush, GoalRewardPush], "MultiPush": [DistRewardMultiPush, GoalRewardMultiPush, NoRewardMultiPush], + "MultiPushSmall": [ + DistRewardMultiPushSmall, + GoalRewardMultiPushSmall, + NoRewardMultiPushSmall, + ], "PushMaze": [DistRewardPushMaze, GoalRewardPushMaze, NoRewardPushMaze], "Fall": [DistRewardFall, GoalRewardFall], "MultiFall": [DistRewardMultiFall, GoalRewardMultiFall, NoRewardMultiFall],