AntBilliard
This commit is contained in:
parent
30c845fb0c
commit
cd55df02b1
@ -13,6 +13,7 @@ class AgentModel(ABC, MujocoEnv, EzPickle):
|
||||
MANUAL_COLLISION: bool
|
||||
ORI_IND: Optional[int] = None
|
||||
RADIUS: Optional[float] = None
|
||||
OBJBALL_TYPE: Optional[str] = None
|
||||
|
||||
def __init__(self, file_path: str, frame_skip: int) -> None:
|
||||
MujocoEnv.__init__(self, file_path, frame_skip)
|
||||
|
@ -39,6 +39,7 @@ class AntEnv(AgentModel):
|
||||
FILE: str = "ant.xml"
|
||||
ORI_IND: int = 3
|
||||
MANUAL_COLLISION: bool = False
|
||||
OBJBALL_TYPE: str = "freejoint"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -165,7 +165,28 @@ class MazeEnv(gym.Env):
|
||||
elif struct.is_object_ball():
|
||||
# Movable Ball
|
||||
self.object_balls.append(f"objball_{i}_{j}")
|
||||
_add_object_ball(worldbody, i, j, x, y, self._task.OBJECT_BALL_SIZE)
|
||||
if model_cls.OBJBALL_TYPE == "hinge":
|
||||
_add_objball_hinge(
|
||||
worldbody,
|
||||
i,
|
||||
j,
|
||||
x,
|
||||
y,
|
||||
self._task.OBJECT_BALL_SIZE,
|
||||
)
|
||||
elif model_cls.OBJBALL_TYPE == "freejoint":
|
||||
_add_objball_freejoint(
|
||||
worldbody,
|
||||
i,
|
||||
j,
|
||||
x,
|
||||
y,
|
||||
self._task.OBJECT_BALL_SIZE,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"OBJBALL_TYPE is not registered for {model_cls}"
|
||||
)
|
||||
|
||||
torso = tree.find(".//body[@name='torso']")
|
||||
geoms = torso.findall(".//geom")
|
||||
@ -185,7 +206,7 @@ class MazeEnv(gym.Env):
|
||||
"site",
|
||||
name=f"goal_site{i}",
|
||||
pos=f"{goal.pos[0]} {goal.pos[1]} {z}",
|
||||
size=f"{maze_size_scaling * 0.1}",
|
||||
size=size,
|
||||
rgba=goal.rgb.rgba_str(),
|
||||
)
|
||||
|
||||
@ -461,7 +482,7 @@ class MazeEnv(gym.Env):
|
||||
self._websock_server_pipe.send(None)
|
||||
|
||||
|
||||
def _add_object_ball(
|
||||
def _add_objball_hinge(
|
||||
worldbody: ET.Element,
|
||||
i: str,
|
||||
j: str,
|
||||
@ -490,7 +511,6 @@ def _add_object_ball(
|
||||
name=f"objball_{i}_{j}_x",
|
||||
axis="1 0 0",
|
||||
pos="0 0 0",
|
||||
range="-1 1",
|
||||
type="slide",
|
||||
)
|
||||
ET.SubElement(
|
||||
@ -499,7 +519,6 @@ def _add_object_ball(
|
||||
name=f"objball_{i}_{j}_y",
|
||||
axis="0 1 0",
|
||||
pos="0 0 0",
|
||||
range="-1 1",
|
||||
type="slide",
|
||||
)
|
||||
ET.SubElement(
|
||||
@ -513,6 +532,30 @@ def _add_object_ball(
|
||||
)
|
||||
|
||||
|
||||
def _add_objball_freejoint(
|
||||
worldbody: ET.Element,
|
||||
i: str,
|
||||
j: str,
|
||||
x: float,
|
||||
y: float,
|
||||
size: float,
|
||||
) -> None:
|
||||
body = ET.SubElement(worldbody, "body", name=f"objball_{i}_{j}", pos=f"{x} {y} 0")
|
||||
ET.SubElement(
|
||||
body,
|
||||
"geom",
|
||||
type="sphere",
|
||||
name=f"objball_{i}_{j}_geom",
|
||||
size=f"{size}", # Radius
|
||||
pos=f"0.0 0.0 {size}", # Z = size so that this ball can move!!
|
||||
rgba=maze_task.BLUE.rgba_str(),
|
||||
contype="1",
|
||||
conaffinity="1",
|
||||
solimp="0.9 0.99 0.001",
|
||||
)
|
||||
ET.SubElement(body, "freejoint", name=f"objball_{i}_{j}_root")
|
||||
|
||||
|
||||
def _add_movable_block(
|
||||
worldbody: ET.Element,
|
||||
struct: maze_env_utils.MazeCell,
|
||||
|
@ -258,6 +258,7 @@ class DistRewardFall(GoalRewardFall, DistRewardMixIn):
|
||||
class GoalRewardMultiFall(GoalRewardUMaze):
|
||||
MAZE_SIZE_SCALING: Scaling = Scaling(ant=2.0, point=None, swimmer=None)
|
||||
OBSERVE_BLOCKS: bool = True
|
||||
PENALTY: float = -0.0001
|
||||
|
||||
def __init__(self, scale: float, goal: Tuple[int, int] = (3.0, 1.0)) -> None:
|
||||
super().__init__(scale)
|
||||
@ -672,12 +673,16 @@ class BanditBilliard(SubGoalBilliard):
|
||||
|
||||
|
||||
class GoalRewardSmallBilliard(GoalRewardBilliard):
|
||||
MAZE_SIZE_SCALING: Scaling = Scaling(ant=1.5, point=4.0, swimmer=None)
|
||||
OBJECT_BALL_SIZE: float = 0.5
|
||||
MAZE_SIZE_SCALING: Scaling = Scaling(ant=2.0, point=4.0, swimmer=None)
|
||||
OBJECT_BALL_SIZE: float = 0.4
|
||||
GOAL_SIZE: float = 0.2
|
||||
|
||||
def __init__(self, scale: float, goal: Tuple[float, float] = (-1.0, -2.0)) -> None:
|
||||
super().__init__(scale, goal)
|
||||
|
||||
def _threshold(self) -> float:
|
||||
return self.OBJECT_BALL_SIZE + self.GOAL_SIZE
|
||||
|
||||
@staticmethod
|
||||
def create_maze() -> List[List[MazeCell]]:
|
||||
E, B = MazeCell.EMPTY, MazeCell.BLOCK
|
||||
|
@ -19,6 +19,7 @@ class PointEnv(AgentModel):
|
||||
ORI_IND: int = 2
|
||||
MANUAL_COLLISION: bool = True
|
||||
RADIUS: float = 0.4
|
||||
OBJBALL_TYPE: str = "hinge"
|
||||
|
||||
VELOCITY_LIMITS: float = 10.0
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user