diff --git a/mujoco_maze/maze_env.py b/mujoco_maze/maze_env.py index 0ed16a0..7619c35 100644 --- a/mujoco_maze/maze_env.py +++ b/mujoco_maze/maze_env.py @@ -167,7 +167,7 @@ class MazeEnv(gym.Env): name=f"goal_site{i}", pos=f"{goal.pos[0]} {goal.pos[1]} {z}", size=f"{maze_size_scaling * 0.1}", - rgba=goal.rbga_str(), + rgba=goal.rgb.rgba_str(), ) _, file_path = tempfile.mkstemp(text=True, suffix=".xml") @@ -420,7 +420,7 @@ def _add_object_ball( name=f"objball_{i}_{j}_geom", size=f"{size}", # Radius pos=f"0.0 0.0 {size}", # Z = size so that this ball can move!! - rgba="0.1 0.1 0.7 1", + rgba=maze_task.BLUE.rgba_str(), contype="1", conaffinity="1", solimp="0.9 0.99 0.001", diff --git a/mujoco_maze/maze_task.py b/mujoco_maze/maze_task.py index 002e943..662d594 100644 --- a/mujoco_maze/maze_task.py +++ b/mujoco_maze/maze_task.py @@ -14,6 +14,9 @@ class Rgb(NamedTuple): green: float blue: float + def rgba_str(self) -> str: + return f"{self.red} {self.green} {self.blue} 1" + RED = Rgb(0.7, 0.1, 0.1) GREEN = Rgb(0.1, 0.7, 0.1) @@ -37,10 +40,6 @@ class MazeGoal: self.threshold = threshold self.custom_size = custom_size - def rbga_str(self) -> str: - r, g, b = self.rgb - return f"{r} {g} {b} 1" - def neighbor(self, obs: np.ndarray) -> float: return np.linalg.norm(obs[: self.dim] - self.pos) <= self.threshold @@ -339,19 +338,31 @@ class GoalRewardBilliard(MazeTask): PENALTY: float = -0.0001 MAZE_SIZE_SCALING: Scaling = Scaling(4.0, 3.0, 3.0) OBSERVE_BALLS: bool = True + GOAL_SIZE: float = 0.3 def __init__(self, scale: float, goal: Tuple[float, float] = (2.0, -3.0)) -> None: super().__init__(scale) goal = np.array(goal) * scale - goal_size = 0.3 - threshold = goal_size + self.OBJECT_BALL_SIZE - self.goals = [MazeGoal(goal, threshold=threshold, custom_size=goal_size)] + self.goals.append( + MazeGoal(goal, threshold=self._threshold(), custom_size=self.GOAL_SIZE) + ) + + def _threshold(self) -> float: + return self.OBJECT_BALL_SIZE + self.GOAL_SIZE def reward(self, obs: np.ndarray) -> float: - return 1.0 if self.termination(obs) else self.PENALTY + object_pos = obs[3:6] + for goal in self.goals: + if goal.neighbor(object_pos): + return goal.reward_scale + return self.PENALTY def termination(self, obs: np.ndarray) -> bool: - return super().termination(obs[3:6]) + object_pos = obs[3:6] + for goal in self.goals: + if goal.neighbor(object_pos): + return True + return False @staticmethod def create_maze() -> List[List[MazeCell]]: @@ -372,6 +383,38 @@ class DistRewardBilliard(GoalRewardBilliard): return -self.goals[0].euc_dist(obs[3:6]) / self.scale +class SubGoalBilliard(GoalRewardBilliard): + def __init__( + self, + scale: float, + primary_goal: Tuple[float, float] = (2.0, -3.0), + subgoal: Tuple[float, float] = (-2.0, -3.0), + ) -> None: + super().__init__(scale, primary_goal) + self.goals.append( + MazeGoal( + np.array(subgoal) * scale, + reward_scale=0.5, + rgb=GREEN, + threshold=self._threshold(), + custom_size=self.GOAL_SIZE, + ) + ) + + @staticmethod + def create_maze() -> List[List[MazeCell]]: + E, B = MazeCell.EMPTY, MazeCell.BLOCK + R, M = MazeCell.ROBOT, MazeCell.OBJECT_BALL + return [ + [B, B, B, B, B, B, B], + [B, E, E, E, E, E, B], + [B, E, E, E, B, B, B], + [B, E, E, M, E, E, B], + [B, E, E, R, E, E, B], + [B, B, B, B, B, B, B], + ] + + class TaskRegistry: REGISTRY: Dict[str, List[Type[MazeTask]]] = { "SimpleRoom": [DistRewardSimpleRoom, GoalRewardSimpleRoom], @@ -382,7 +425,7 @@ class TaskRegistry: "4Rooms": [DistReward4Rooms, GoalReward4Rooms, SubGoal4Rooms], "TRoom": [DistRewardTRoom, GoalRewardTRoom], "BlockMaze": [DistRewardBlockMaze, GoalRewardBlockMaze], - "Billiard": [DistRewardBilliard, GoalRewardBilliard], + "Billiard": [DistRewardBilliard, GoalRewardBilliard, SubGoalBilliard], } @staticmethod diff --git a/tests/test_envs.py b/tests/test_envs.py index d769fd8..92dc9e8 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -36,6 +36,20 @@ def test_point_maze(maze_id): assert r < 0.0 +@pytest.mark.parametrize("maze_id", ["2Rooms", "4Rooms", "Billiard"]) +def test_subgoal_envs(maze_id): + env = gym.make(f"Point{maze_id}-v2") + s0 = env.reset() + s, r, _, _ = env.step(env.action_space.sample()) + if not env.unwrapped.has_extended_obs: + assert s0.shape == (7,) + assert s.shape == (7,) + elif env.unwrapped._observe_balls: + assert s0.shape == (10,) + assert s.shape == (10,) + assert len(env.unwrapped._task.goals) > 1 + + @pytest.mark.parametrize("maze_id", mujoco_maze.TaskRegistry.keys()) def test_reacher_maze(maze_id): for inhibited in ["Fall", "Push", "Block", "Billiard"]: