Add a reward-free variant of the MultiPush task.

NoRewardMultiPush subclasses GoalRewardMultiPush and overrides reward() so
that it returns 0.0 for any observation; nothing else is overridden. The
second hunk appends the new class as the third entry of the "MultiPush" list
(hunk context: class TaskRegistry), after DistRewardMultiPush and
GoalRewardMultiPush.

NOTE(review): this patch was recovered from a copy in which all line breaks
and leading whitespace had been collapsed onto one line. Blank lines were
restored from the hunk line counts; indentation is assumed to be standard
4-space Python indentation (8 spaces for the registry entries). Confirm
against mujoco_maze/maze_task.py before applying, or apply with
whitespace-insensitive matching.

diff --git a/mujoco_maze/maze_task.py b/mujoco_maze/maze_task.py
index 7a425b4..c92b89a 100644
--- a/mujoco_maze/maze_task.py
+++ b/mujoco_maze/maze_task.py
@@ -225,6 +225,11 @@ class DistRewardMultiPush(GoalRewardMultiPush, DistRewardMixIn):
     pass
 
 
+class NoRewardMultiPush(GoalRewardMultiPush):
+    def reward(self, _obs: np.ndarray) -> float:
+        return 0.0
+
+
 class GoalRewardFall(GoalRewardUMaze):
     OBSERVE_BLOCKS: bool = True
 
@@ -587,7 +592,7 @@ class TaskRegistry:
         "SquareRoom": [DistRewardSquareRoom, GoalRewardSquareRoom, NoRewardSquareRoom],
         "UMaze": [DistRewardUMaze, GoalRewardUMaze],
         "Push": [DistRewardPush, GoalRewardPush],
-        "MultiPush": [DistRewardMultiPush, GoalRewardMultiPush],
+        "MultiPush": [DistRewardMultiPush, GoalRewardMultiPush, NoRewardMultiPush],
         "Fall": [DistRewardFall, GoalRewardFall],
         "2Rooms": [DistReward2Rooms, GoalReward2Rooms, SubGoal2Rooms],
         "4Rooms": [DistReward4Rooms, GoalReward4Rooms, SubGoal4Rooms],