diff --git a/mujoco_maze/maze_task.py b/mujoco_maze/maze_task.py index 5ba8096..6c97225 100644 --- a/mujoco_maze/maze_task.py +++ b/mujoco_maze/maze_task.py @@ -51,13 +51,13 @@ class Scaling(NamedTuple): class MazeTask(ABC): REWARD_THRESHOLD: float MAZE_SIZE_SCALING: Scaling = Scaling(8.0, 4.0) - INNER_REWARD_SCALING: float = 0.0 + INNER_REWARD_SCALING: float = 0.01 OBSERVE_BLOCKS: bool = False PUT_SPIN_NEAR_AGENT: bool = False def __init__(self, scale: float) -> None: - self.scale = scale self.goals = [] + self.scale = scale def sample_goals(self) -> bool: return False @@ -78,7 +78,16 @@ class MazeTask(ABC): pass -class SingleGoalSparseUMaze(MazeTask): +class DistRewardMixIn: + REWARD_THRESHOLD: float = -1000.0 + goals: List[MazeGoal] + scale: float + + def reward(self, obs: np.ndarray) -> float: + return -self.goals[0].euc_dist(obs) / self.scale + + +class GoalRewardUMaze(MazeTask): REWARD_THRESHOLD: float = 0.9 def __init__(self, scale: float) -> None: @@ -100,14 +109,11 @@ class SingleGoalSparseUMaze(MazeTask): ] -class SingleGoalDenseUMaze(SingleGoalSparseUMaze): - REWARD_THRESHOLD: float = 1000.0 - - def reward(self, obs: np.ndarray) -> float: - return -self.goals[0].euc_dist(obs) +class DistRewardUMaze(GoalRewardUMaze, DistRewardMixIn): + pass -class SingleGoalSparsePush(SingleGoalSparseUMaze): +class GoalRewardPush(GoalRewardUMaze): def __init__(self, scale: float) -> None: super().__init__(scale) self.goals = [MazeGoal(np.array([0.0, 2.375 * scale]))] @@ -124,14 +130,11 @@ class SingleGoalSparsePush(SingleGoalSparseUMaze): ] -class SingleGoalDensePush(SingleGoalSparsePush): - REWARD_THRESHOLD: float = 1000.0 - - def reward(self, obs: np.ndarray) -> float: - return -self.goals[0].euc_dist(obs) +class DistRewardPush(GoalRewardPush, DistRewardMixIn): + pass -class SingleGoalSparseFall(SingleGoalSparseUMaze): +class GoalRewardFall(GoalRewardUMaze): def __init__(self, scale: float) -> None: super().__init__(scale) self.goals = [MazeGoal(np.array([0.0, 3.375 * scale, 4.5]))] @@ -149,14 +152,11 @@ class SingleGoalSparseFall(SingleGoalSparseUMaze): ] -class SingleGoalDenseFall(SingleGoalSparseFall): - REWARD_THRESHOLD: float = 1000.0 - - def reward(self, obs: np.ndarray) -> float: - return -self.goals[0].euc_dist(obs) +class DistRewardFall(GoalRewardFall, DistRewardMixIn): + pass -class SingleGoalSparse2Rooms(MazeTask): +class GoalReward2Rooms(MazeTask): REWARD_THRESHOLD: float = 0.9 MAZE_SIZE_SCALING: Scaling = Scaling(4.0, 4.0) @@ -184,20 +184,17 @@ class SingleGoalSparse2Rooms(MazeTask): ] -class SingleGoalDense2Rooms(SingleGoalSparse2Rooms): - REWARD_THRESHOLD: float = 1000.0 - - def reward(self, obs: np.ndarray) -> float: - return -self.goals[0].euc_dist(obs) +class DistReward2Rooms(GoalReward2Rooms, DistRewardMixIn): + pass -class SubGoalSparse2Rooms(SingleGoalSparse2Rooms): +class SubGoal2Rooms(GoalReward2Rooms): def __init__(self, scale: float) -> None: super().__init__(scale) self.goals.append(MazeGoal(np.array([5.0 * scale, 0.0 * scale]), 0.5, GREEN)) -class SingleGoalSparse4Rooms(MazeTask): +class GoalReward4Rooms(MazeTask): REWARD_THRESHOLD: float = 0.9 MAZE_SIZE_SCALING: Scaling = Scaling(4.0, 4.0) @@ -227,14 +224,11 @@ class SingleGoalSparse4Rooms(MazeTask): ] -class SingleGoalDense4Rooms(SingleGoalSparse4Rooms): - REWARD_THRESHOLD: float = 1000.0 - - def reward(self, obs: np.ndarray) -> float: - return -self.goals[0].euc_dist(obs) +class DistReward4Rooms(GoalReward4Rooms, DistRewardMixIn): + pass -class SubGoalSparse4Rooms(SingleGoalSparse4Rooms): +class SubGoal4Rooms(GoalReward4Rooms): def __init__(self, scale: float) -> None: super().__init__(scale) self.goals += [ @@ -245,11 +239,11 @@ class SubGoalSparse4Rooms(SingleGoalSparse4Rooms): class TaskRegistry: REGISTRY: Dict[str, List[Type[MazeTask]]] = { - "UMaze": [SingleGoalDenseUMaze, SingleGoalSparseUMaze], - "Push": [SingleGoalDensePush, SingleGoalSparsePush], - "Fall": [SingleGoalDenseFall, SingleGoalSparseFall], - "2Rooms": [SingleGoalDense2Rooms, SingleGoalSparse2Rooms, SubGoalSparse2Rooms], - "4Rooms": [SingleGoalSparse4Rooms, SingleGoalDense4Rooms, SubGoalSparse4Rooms], + "UMaze": [DistRewardUMaze, GoalRewardUMaze], + "Push": [DistRewardPush, GoalRewardPush], + "Fall": [DistRewardFall, GoalRewardFall], + "2Rooms": [DistReward2Rooms, GoalReward2Rooms, SubGoal2Rooms], + "4Rooms": [DistReward4Rooms, GoalReward4Rooms, SubGoal4Rooms], } @staticmethod