From 98fe8b977e5937cadb733ad9dc0a7aac24f43fd4 Mon Sep 17 00:00:00 2001 From: Yuji Kanagawa Date: Sat, 2 Oct 2021 17:12:40 +0900 Subject: [PATCH] Ant Billiard --- mujoco_maze/maze_env.py | 5 +++-- mujoco_maze/maze_task.py | 37 ++++++++++++++++++++++++++++++++++++- 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/mujoco_maze/maze_env.py b/mujoco_maze/maze_env.py index 5537729..ce8ad7b 100644 --- a/mujoco_maze/maze_env.py +++ b/mujoco_maze/maze_env.py @@ -75,7 +75,6 @@ class MazeEnv(gym.Env): torso_y, model_cls.RADIUS, ) - # Now all object balls have size=1.0 self._objball_collision = maze_env_utils.CollisionDetector( structure, size_scaling, @@ -490,7 +489,8 @@ def _add_object_ball( "joint", name=f"objball_{i}_{j}_x", axis="1 0 0", - pos="0 0 0.0", + pos="0 0 0", + range="-1 1", type="slide", ) ET.SubElement( @@ -499,6 +499,7 @@ def _add_object_ball( name=f"objball_{i}_{j}_y", axis="0 1 0", pos="0 0 0", + range="-1 1", type="slide", ) ET.SubElement( diff --git a/mujoco_maze/maze_task.py b/mujoco_maze/maze_task.py index 6a8a3b8..fac467d 100644 --- a/mujoco_maze/maze_task.py +++ b/mujoco_maze/maze_task.py @@ -256,7 +256,7 @@ class DistRewardFall(GoalRewardFall, DistRewardMixIn): class GoalRewardMultiFall(GoalRewardUMaze): - MAZE_SIZE_SCALING: Scaling = Scaling(ant=2.0, point=6.0, swimmer=None) + MAZE_SIZE_SCALING: Scaling = Scaling(ant=2.0, point=None, swimmer=None) OBSERVE_BLOCKS: bool = True def __init__(self, scale: float, goal: Tuple[int, int] = (0.0, 4.0)) -> None: @@ -674,6 +674,36 @@ class BanditBilliard(SubGoalBilliard): ] +class GoalRewardSmallBilliard(GoalRewardBilliard): + MAZE_SIZE_SCALING: Scaling = Scaling(ant=1.5, point=4.0, swimmer=None) + OBJECT_BALL_SIZE: float = 0.5 + + def __init__(self, scale: float, goal: Tuple[float, float] = (-21.0, -2.0)) -> None: + super().__init__(scale, goal) + + @staticmethod + def create_maze() -> List[List[MazeCell]]: + E, B = MazeCell.EMPTY, MazeCell.BLOCK + R, M = MazeCell.ROBOT, MazeCell.OBJECT_BALL + return [ + [B, B, B, B, B], + [B, E, E, E, B], + [B, E, M, E, B], + [B, E, R, E, B], + [B, E, E, E, B], + [B, B, B, B, B], + ] + + +class DistRewardSmallBilliard(GoalRewardSmallBilliard, DistRewardMixIn): + pass + + +class NoRewardSmallBilliard(GoalRewardSmallBilliard): + def reward(self, _obs: np.ndarray) -> float: + return 0.0 + + class TaskRegistry: REGISTRY: Dict[str, List[Type[MazeTask]]] = { "SimpleRoom": [DistRewardSimpleRoom, GoalRewardSimpleRoom], @@ -697,6 +727,11 @@ class TaskRegistry: BanditBilliard, # v3 NoRewardBilliard, # v4 ], + "SmallBilliard": [ + DistRewardSmallBilliard, + GoalRewardSmallBilliard, + NoRewardSmallBilliard, + ], } @staticmethod