From 3b8ec64c267f6201cc0b6a26fa595279287a0ddf Mon Sep 17 00:00:00 2001
From: kngwyu
Date: Thu, 1 Oct 2020 01:31:29 +0900
Subject: [PATCH] Add SubGoalTRoom

---
Notes (this region is not part of the commit message):

  * Adds SubGoalTRoom, a subclass of GoalRewardTRoom. Its __init__ forwards
    `scale` and `primary_goal` to the parent constructor, then appends one
    extra MazeGoal to `self.goals`, positioned at `subgoal * scale`, with
    reward_scale=0.5 and rgb=GREEN.
  * Registers SubGoalTRoom as a third entry in the "TRoom" list of
    TaskRegistry.
  * NOTE(review): this patch was recovered from a copy whose line breaks had
    been collapsed. Leading whitespace inside the hunks was reconstructed
    assuming 4-space indentation; confirm against the repository before
    applying.

 mujoco_maze/maze_task.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/mujoco_maze/maze_task.py b/mujoco_maze/maze_task.py
index c626c0a..21f4895 100644
--- a/mujoco_maze/maze_task.py
+++ b/mujoco_maze/maze_task.py
@@ -316,6 +316,19 @@ class DistRewardTRoom(GoalRewardTRoom, DistRewardMixIn):
     pass
 
 
+class SubGoalTRoom(GoalRewardTRoom):
+    def __init__(
+        self,
+        scale: float,
+        primary_goal: Tuple[float, float] = (2.0, -3.0),
+        subgoal: Tuple[float, float] = (-2.0, -3.0),
+    ) -> None:
+        super().__init__(scale, primary_goal)
+        self.goals.append(
+            MazeGoal(np.array(subgoal) * scale, reward_scale=0.5, rgb=GREEN)
+        )
+
+
 class GoalRewardBlockMaze(GoalRewardUMaze):
     OBSERVE_BLOCKS: bool = True
 
@@ -444,7 +457,7 @@ class TaskRegistry:
         "Fall": [DistRewardFall, GoalRewardFall],
         "2Rooms": [DistReward2Rooms, GoalReward2Rooms, SubGoal2Rooms],
         "4Rooms": [DistReward4Rooms, GoalReward4Rooms, SubGoal4Rooms],
-        "TRoom": [DistRewardTRoom, GoalRewardTRoom],
+        "TRoom": [DistRewardTRoom, GoalRewardTRoom, SubGoalTRoom],
         "BlockMaze": [DistRewardBlockMaze, GoalRewardBlockMaze],
         "Billiard": [
             DistRewardBilliard,