From 06cf7c9b8b4ca426b9b777cd7748fcaf34a1151e Mon Sep 17 00:00:00 2001 From: Yuji Kanagawa Date: Sat, 2 Oct 2021 13:15:00 +0900 Subject: [PATCH] BlockCarry --- mujoco_maze/maze_task.py | 57 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/mujoco_maze/maze_task.py b/mujoco_maze/maze_task.py index 31c21c9..6a8a3b8 100644 --- a/mujoco_maze/maze_task.py +++ b/mujoco_maze/maze_task.py @@ -515,6 +515,62 @@ class DistRewardBlockMaze(GoalRewardBlockMaze, DistRewardMixIn): pass +class GoalRewardBlockCarry(MazeTask): + REWARD_THRESHOLD: float = 0.9 + PENALTY: float = -0.0001 + MAZE_SIZE_SCALING: Scaling = Scaling(ant=2.0, point=3.0, swimmer=None) + OBSERVE_BLOCKS: bool = True + GOAL_SIZE: float = 0.3 + + def __init__(self, scale: float, goal: Tuple[float, float] = (2.0, 0.0)) -> None: + super().__init__(scale) + self.goals.append( + MazeGoal( + np.array(goal) * scale, + threshold=self.GOAL_SIZE + 0.5, + custom_size=self.GOAL_SIZE, + ) + ) + + def reward(self, obs: np.ndarray) -> float: + object_pos = obs[3:6] + for goal in self.goals: + if goal.neighbor(object_pos): + return goal.reward_scale + return self.PENALTY + + def termination(self, obs: np.ndarray) -> bool: + object_pos = obs[3:6] + for goal in self.goals: + if goal.neighbor(object_pos): + return True + return False + + @staticmethod + def create_maze() -> List[List[MazeCell]]: + E, B = MazeCell.EMPTY, MazeCell.BLOCK + R, M = MazeCell.ROBOT, MazeCell.XY_BLOCK + return [ + [B, B, B, B, B], + [B, E, E, E, B], + [B, E, E, E, B], + [B, R, M, E, B], + [B, E, E, E, B], + [B, E, E, E, B], + [B, B, B, B, B], + ] + + +class DistRewardBlockCarry(GoalRewardBlockCarry): + def reward(self, obs: np.ndarray) -> float: + return -self.goals[0].euc_dist(obs[3:6]) / self.scale + + +class NoRewardBlockCarry(GoalRewardBlockCarry): + def reward(self, _obs: np.ndarray) -> float: + return 0.0 + + class GoalRewardBilliard(MazeTask): REWARD_THRESHOLD: float = 0.9 PENALTY: float = -0.0001 @@ -633,6 +689,7 @@ class TaskRegistry: "BlockMaze": [DistRewardBlockMaze, GoalRewardBlockMaze], "Corridor": [DistRewardCorridor, GoalRewardCorridor, NoRewardCorridor], "LongCorridor": [DistRewardLongCorridor, GoalRewardLongCorridor], + "BlockCarry": [DistRewardBlockCarry, GoalRewardBlockCarry, NoRewardBlockCarry], "Billiard": [ DistRewardBilliard, # v0 GoalRewardBilliard, # v1