From 71f89461ee56868c1fc9252065f119d0825bc92a Mon Sep 17 00:00:00 2001
From: Yuji Kanagawa
Date: Tue, 28 Sep 2021 16:42:22 +0900
Subject: [PATCH] MultiFall

---
 mujoco_maze/maze_task.py | 42 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 42 insertions(+)

diff --git a/mujoco_maze/maze_task.py b/mujoco_maze/maze_task.py
index aca6342..a84de2f 100644
--- a/mujoco_maze/maze_task.py
+++ b/mujoco_maze/maze_task.py
@@ -255,6 +255,47 @@ class DistRewardFall(GoalRewardFall, DistRewardMixIn):
     pass
 
 
+class GoalRewardMultiFall(GoalRewardUMaze):
+    """Goal-reward task on a maze containing CHASM cells and one YZ_BLOCK cell."""
+
+    MAZE_SIZE_SCALING: Scaling = Scaling(ant=2.0, point=6.0, swimmer=None)
+    OBSERVE_BLOCKS: bool = True
+
+    def __init__(self, scale: float, goal: Tuple[float, float] = (0.0, 3.0)) -> None:
+        """Set ``self.goals`` to one ``MazeGoal`` at ``(*goal, 0.5) * scale``."""
+        super().__init__(scale)
+        self.goals = [MazeGoal(np.array([*goal, 0.5]) * scale)]
+
+    @staticmethod
+    def create_maze() -> List[List[MazeCell]]:
+        """Return the 7x9 grid of cells that defines the MultiFall maze."""
+        E, B, C, R = MazeCell.EMPTY, MazeCell.BLOCK, MazeCell.CHASM, MazeCell.ROBOT
+        M = MazeCell.YZ_BLOCK
+        return [
+            [B, B, B, B, B, B, B, B, B],
+            [B, E, C, E, E, E, C, E, B],
+            [B, E, C, E, R, E, C, E, B],
+            [B, E, C, E, M, E, C, E, B],
+            [B, B, B, C, C, C, B, B, B],
+            [B, B, B, E, E, E, B, B, B],
+            [B, B, B, B, B, B, B, B, B],
+        ]
+
+
+class DistRewardMultiFall(GoalRewardMultiFall, DistRewardMixIn):
+    """MultiFall task combining ``GoalRewardMultiFall`` with ``DistRewardMixIn``."""
+
+    pass
+
+
+class NoRewardMultiFall(GoalRewardMultiFall):
+    """MultiFall task whose ``reward`` always returns 0.0."""
+
+    def reward(self, _obs: np.ndarray) -> float:
+        """Return 0.0 regardless of the observation."""
+        return 0.0
+
+
 class GoalReward2Rooms(MazeTask):
     REWARD_THRESHOLD: float = 0.9
     PENALTY: float = -0.0001
@@ -594,6 +635,7 @@ class TaskRegistry:
         "Push": [DistRewardPush, GoalRewardPush],
         "MultiPush": [DistRewardMultiPush, GoalRewardMultiPush, NoRewardMultiPush],
         "Fall": [DistRewardFall, GoalRewardFall],
+        "MultiFall": [DistRewardMultiFall, GoalRewardMultiFall, NoRewardMultiFall],
         "2Rooms": [DistReward2Rooms, GoalReward2Rooms, SubGoal2Rooms],
         "4Rooms": [DistReward4Rooms, GoalReward4Rooms, SubGoal4Rooms],
         "TRoom": [DistRewardTRoom, GoalRewardTRoom, SubGoalTRoom],