diff --git a/mujoco_maze/maze_task.py b/mujoco_maze/maze_task.py index 271b484..c626c0a 100644 --- a/mujoco_maze/maze_task.py +++ b/mujoco_maze/maze_task.py @@ -396,19 +396,30 @@ class SubGoalBilliard(GoalRewardBilliard): def __init__( self, scale: float, - primary_goal: Tuple[float, float] = (4.0, -2.0), - subgoal: Tuple[float, float] = (4.0, 2.0), + primary_goal: Tuple[float, float] = (2.0, -3.0), + subgoals: List[Tuple[float, float]] = [(-2.0, -3.0), (-2.0, 1.0), (2.0, 1.0)], ) -> None: super().__init__(scale, primary_goal) - self.goals.append( - MazeGoal( - np.array(subgoal) * scale, - reward_scale=0.5, - rgb=GREEN, - threshold=self._threshold(), - custom_size=self.GOAL_SIZE, + for subgoal in subgoals: + self.goals.append( + MazeGoal( + np.array(subgoal) * scale, + reward_scale=0.5, + rgb=GREEN, + threshold=self._threshold(), + custom_size=self.GOAL_SIZE, + ) ) - ) + + +class BanditBilliard(SubGoalBilliard): + def __init__( + self, + scale: float, + primary_goal: Tuple[float, float] = (4.0, -2.0), + subgoals: List[Tuple[float, float]] = [(4.0, 2.0)], + ) -> None: + super().__init__(scale, primary_goal, subgoals) @staticmethod def create_maze() -> List[List[MazeCell]]: @@ -435,7 +446,12 @@ class TaskRegistry: "4Rooms": [DistReward4Rooms, GoalReward4Rooms, SubGoal4Rooms], "TRoom": [DistRewardTRoom, GoalRewardTRoom], "BlockMaze": [DistRewardBlockMaze, GoalRewardBlockMaze], - "Billiard": [DistRewardBilliard, GoalRewardBilliard, SubGoalBilliard], + "Billiard": [ + DistRewardBilliard, + GoalRewardBilliard, + SubGoalBilliard, + BanditBilliard, + ], } @staticmethod