From 28711cee1921529b618f119197ef6137fc71d3fa Mon Sep 17 00:00:00 2001 From: kngwyu Date: Thu, 1 Oct 2020 00:20:48 +0900 Subject: [PATCH] Introduce BanditBilliard and change SubGoalBilliard to a more normal one --- mujoco_maze/maze_task.py | 38 +++++++++++++++++++++++++++----------- 1 file changed, 27 insertions(+), 11 deletions(-) diff --git a/mujoco_maze/maze_task.py b/mujoco_maze/maze_task.py index 271b484..c626c0a 100644 --- a/mujoco_maze/maze_task.py +++ b/mujoco_maze/maze_task.py @@ -396,19 +396,30 @@ class SubGoalBilliard(GoalRewardBilliard): def __init__( self, scale: float, - primary_goal: Tuple[float, float] = (4.0, -2.0), - subgoal: Tuple[float, float] = (4.0, 2.0), + primary_goal: Tuple[float, float] = (2.0, -3.0), + subgoals: List[Tuple[float, float]] = [(-2.0, -3.0), (-2.0, 1.0), (2.0, 1.0)], ) -> None: super().__init__(scale, primary_goal) - self.goals.append( - MazeGoal( - np.array(subgoal) * scale, - reward_scale=0.5, - rgb=GREEN, - threshold=self._threshold(), - custom_size=self.GOAL_SIZE, + for subgoal in subgoals: + self.goals.append( + MazeGoal( + np.array(subgoal) * scale, + reward_scale=0.5, + rgb=GREEN, + threshold=self._threshold(), + custom_size=self.GOAL_SIZE, + ) ) - ) + + +class BanditBilliard(SubGoalBilliard): + def __init__( + self, + scale: float, + primary_goal: Tuple[float, float] = (4.0, -2.0), + subgoals: List[Tuple[float, float]] = [(4.0, 2.0)], + ) -> None: + super().__init__(scale, primary_goal, subgoals) @staticmethod def create_maze() -> List[List[MazeCell]]: @@ -435,7 +446,12 @@ class TaskRegistry: "4Rooms": [DistReward4Rooms, GoalReward4Rooms, SubGoal4Rooms], "TRoom": [DistRewardTRoom, GoalRewardTRoom], "BlockMaze": [DistRewardBlockMaze, GoalRewardBlockMaze], - "Billiard": [DistRewardBilliard, GoalRewardBilliard, SubGoalBilliard], + "Billiard": [ + DistRewardBilliard, + GoalRewardBilliard, + SubGoalBilliard, + BanditBilliard, + ], } @staticmethod