From 362bba2f79fa5a93095f5f8f50ead04342b4d6b6 Mon Sep 17 00:00:00 2001 From: kngwyu Date: Thu, 27 May 2021 00:05:09 +0900 Subject: [PATCH] SquareRoom --- mujoco_maze/maze_task.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/mujoco_maze/maze_task.py b/mujoco_maze/maze_task.py index d8b4572..4ee801b 100644 --- a/mujoco_maze/maze_task.py +++ b/mujoco_maze/maze_task.py @@ -145,6 +145,37 @@ class DistRewardSimpleRoom(GoalRewardSimpleRoom, DistRewardMixIn): pass +class GoalRewardSquareRoom(GoalRewardUMaze): + MAZE_SIZE_SCALING: Scaling = Scaling(4.0, 4.0, 2.0) + + def __init__(self, scale: float, goal: Tuple[float, float] = (2.0, 0.0)) -> None: + super().__init__(scale) + self.goals = [MazeGoal(np.array(goal) * scale)] + + @staticmethod + def create_maze() -> List[List[MazeCell]]: + E, B, R = MazeCell.EMPTY, MazeCell.BLOCK, MazeCell.ROBOT + return [ + [B, B, B, B, B], + [B, E, E, E, B], + [B, R, E, E, B], + [B, E, E, E, B], + [B, B, B, B, B], + ] + + +class NoRewardSquareRoom(GoalRewardSimpleRoom): + def __init__(self, scale: float) -> None: + super().__init__(scale) + + def reward(self) -> float: + return 0.0 + + +class DistRewardSquareRoom(GoalRewardSquareRoom, DistRewardMixIn): + pass + + class GoalRewardPush(GoalRewardUMaze): OBSERVE_BLOCKS: bool = True @@ -503,6 +534,7 @@ class BanditBilliard(SubGoalBilliard): class TaskRegistry: REGISTRY: Dict[str, List[Type[MazeTask]]] = { "SimpleRoom": [DistRewardSimpleRoom, GoalRewardSimpleRoom], + "SquareRoom": [DistRewardSquareRoom, GoalRewardSquareRoom, NoRewardSquareRoom], "UMaze": [DistRewardUMaze, GoalRewardUMaze], "Push": [DistRewardPush, GoalRewardPush], "Fall": [DistRewardFall, GoalRewardFall],