Ant Billiard
This commit is contained in:
parent
06cf7c9b8b
commit
98fe8b977e
@ -75,7 +75,6 @@ class MazeEnv(gym.Env):
|
|||||||
torso_y,
|
torso_y,
|
||||||
model_cls.RADIUS,
|
model_cls.RADIUS,
|
||||||
)
|
)
|
||||||
# Now all object balls have size=1.0
|
|
||||||
self._objball_collision = maze_env_utils.CollisionDetector(
|
self._objball_collision = maze_env_utils.CollisionDetector(
|
||||||
structure,
|
structure,
|
||||||
size_scaling,
|
size_scaling,
|
||||||
@ -490,7 +489,8 @@ def _add_object_ball(
|
|||||||
"joint",
|
"joint",
|
||||||
name=f"objball_{i}_{j}_x",
|
name=f"objball_{i}_{j}_x",
|
||||||
axis="1 0 0",
|
axis="1 0 0",
|
||||||
pos="0 0 0.0",
|
pos="0 0 0",
|
||||||
|
range="-1 1",
|
||||||
type="slide",
|
type="slide",
|
||||||
)
|
)
|
||||||
ET.SubElement(
|
ET.SubElement(
|
||||||
@ -499,6 +499,7 @@ def _add_object_ball(
|
|||||||
name=f"objball_{i}_{j}_y",
|
name=f"objball_{i}_{j}_y",
|
||||||
axis="0 1 0",
|
axis="0 1 0",
|
||||||
pos="0 0 0",
|
pos="0 0 0",
|
||||||
|
range="-1 1",
|
||||||
type="slide",
|
type="slide",
|
||||||
)
|
)
|
||||||
ET.SubElement(
|
ET.SubElement(
|
||||||
|
@ -256,7 +256,7 @@ class DistRewardFall(GoalRewardFall, DistRewardMixIn):
|
|||||||
|
|
||||||
|
|
||||||
class GoalRewardMultiFall(GoalRewardUMaze):
|
class GoalRewardMultiFall(GoalRewardUMaze):
|
||||||
MAZE_SIZE_SCALING: Scaling = Scaling(ant=2.0, point=6.0, swimmer=None)
|
MAZE_SIZE_SCALING: Scaling = Scaling(ant=2.0, point=None, swimmer=None)
|
||||||
OBSERVE_BLOCKS: bool = True
|
OBSERVE_BLOCKS: bool = True
|
||||||
|
|
||||||
def __init__(self, scale: float, goal: Tuple[int, int] = (0.0, 4.0)) -> None:
|
def __init__(self, scale: float, goal: Tuple[int, int] = (0.0, 4.0)) -> None:
|
||||||
@ -674,6 +674,36 @@ class BanditBilliard(SubGoalBilliard):
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class GoalRewardSmallBilliard(GoalRewardBilliard):
|
||||||
|
MAZE_SIZE_SCALING: Scaling = Scaling(ant=1.5, point=4.0, swimmer=None)
|
||||||
|
OBJECT_BALL_SIZE: float = 0.5
|
||||||
|
|
||||||
|
def __init__(self, scale: float, goal: Tuple[float, float] = (-21.0, -2.0)) -> None:
|
||||||
|
super().__init__(scale, goal)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_maze() -> List[List[MazeCell]]:
|
||||||
|
E, B = MazeCell.EMPTY, MazeCell.BLOCK
|
||||||
|
R, M = MazeCell.ROBOT, MazeCell.OBJECT_BALL
|
||||||
|
return [
|
||||||
|
[B, B, B, B, B],
|
||||||
|
[B, E, E, E, B],
|
||||||
|
[B, E, M, E, B],
|
||||||
|
[B, E, R, E, B],
|
||||||
|
[B, E, E, E, B],
|
||||||
|
[B, B, B, B, B],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class DistRewardSmallBilliard(GoalRewardSmallBilliard, DistRewardMixIn):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class NoRewardSmallBilliard(GoalRewardSmallBilliard):
|
||||||
|
def reward(self, _obs: np.ndarray) -> float:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
|
||||||
class TaskRegistry:
|
class TaskRegistry:
|
||||||
REGISTRY: Dict[str, List[Type[MazeTask]]] = {
|
REGISTRY: Dict[str, List[Type[MazeTask]]] = {
|
||||||
"SimpleRoom": [DistRewardSimpleRoom, GoalRewardSimpleRoom],
|
"SimpleRoom": [DistRewardSimpleRoom, GoalRewardSimpleRoom],
|
||||||
@ -697,6 +727,11 @@ class TaskRegistry:
|
|||||||
BanditBilliard, # v3
|
BanditBilliard, # v3
|
||||||
NoRewardBilliard, # v4
|
NoRewardBilliard, # v4
|
||||||
],
|
],
|
||||||
|
"SmallBilliard": [
|
||||||
|
DistRewardSmallBilliard,
|
||||||
|
GoalRewardSmallBilliard,
|
||||||
|
NoRewardSmallBilliard,
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
Loading…
Reference in New Issue
Block a user