AntBilliard
This commit is contained in:
parent
30c845fb0c
commit
cd55df02b1
@ -13,6 +13,7 @@ class AgentModel(ABC, MujocoEnv, EzPickle):
|
|||||||
MANUAL_COLLISION: bool
|
MANUAL_COLLISION: bool
|
||||||
ORI_IND: Optional[int] = None
|
ORI_IND: Optional[int] = None
|
||||||
RADIUS: Optional[float] = None
|
RADIUS: Optional[float] = None
|
||||||
|
OBJBALL_TYPE: Optional[str] = None
|
||||||
|
|
||||||
def __init__(self, file_path: str, frame_skip: int) -> None:
|
def __init__(self, file_path: str, frame_skip: int) -> None:
|
||||||
MujocoEnv.__init__(self, file_path, frame_skip)
|
MujocoEnv.__init__(self, file_path, frame_skip)
|
||||||
|
@ -39,6 +39,7 @@ class AntEnv(AgentModel):
|
|||||||
FILE: str = "ant.xml"
|
FILE: str = "ant.xml"
|
||||||
ORI_IND: int = 3
|
ORI_IND: int = 3
|
||||||
MANUAL_COLLISION: bool = False
|
MANUAL_COLLISION: bool = False
|
||||||
|
OBJBALL_TYPE: str = "freejoint"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -165,7 +165,28 @@ class MazeEnv(gym.Env):
|
|||||||
elif struct.is_object_ball():
|
elif struct.is_object_ball():
|
||||||
# Movable Ball
|
# Movable Ball
|
||||||
self.object_balls.append(f"objball_{i}_{j}")
|
self.object_balls.append(f"objball_{i}_{j}")
|
||||||
_add_object_ball(worldbody, i, j, x, y, self._task.OBJECT_BALL_SIZE)
|
if model_cls.OBJBALL_TYPE == "hinge":
|
||||||
|
_add_objball_hinge(
|
||||||
|
worldbody,
|
||||||
|
i,
|
||||||
|
j,
|
||||||
|
x,
|
||||||
|
y,
|
||||||
|
self._task.OBJECT_BALL_SIZE,
|
||||||
|
)
|
||||||
|
elif model_cls.OBJBALL_TYPE == "freejoint":
|
||||||
|
_add_objball_freejoint(
|
||||||
|
worldbody,
|
||||||
|
i,
|
||||||
|
j,
|
||||||
|
x,
|
||||||
|
y,
|
||||||
|
self._task.OBJECT_BALL_SIZE,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"OBJBALL_TYPE is not registered for {model_cls}"
|
||||||
|
)
|
||||||
|
|
||||||
torso = tree.find(".//body[@name='torso']")
|
torso = tree.find(".//body[@name='torso']")
|
||||||
geoms = torso.findall(".//geom")
|
geoms = torso.findall(".//geom")
|
||||||
@ -185,7 +206,7 @@ class MazeEnv(gym.Env):
|
|||||||
"site",
|
"site",
|
||||||
name=f"goal_site{i}",
|
name=f"goal_site{i}",
|
||||||
pos=f"{goal.pos[0]} {goal.pos[1]} {z}",
|
pos=f"{goal.pos[0]} {goal.pos[1]} {z}",
|
||||||
size=f"{maze_size_scaling * 0.1}",
|
size=size,
|
||||||
rgba=goal.rgb.rgba_str(),
|
rgba=goal.rgb.rgba_str(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -461,7 +482,7 @@ class MazeEnv(gym.Env):
|
|||||||
self._websock_server_pipe.send(None)
|
self._websock_server_pipe.send(None)
|
||||||
|
|
||||||
|
|
||||||
def _add_object_ball(
|
def _add_objball_hinge(
|
||||||
worldbody: ET.Element,
|
worldbody: ET.Element,
|
||||||
i: str,
|
i: str,
|
||||||
j: str,
|
j: str,
|
||||||
@ -490,7 +511,6 @@ def _add_object_ball(
|
|||||||
name=f"objball_{i}_{j}_x",
|
name=f"objball_{i}_{j}_x",
|
||||||
axis="1 0 0",
|
axis="1 0 0",
|
||||||
pos="0 0 0",
|
pos="0 0 0",
|
||||||
range="-1 1",
|
|
||||||
type="slide",
|
type="slide",
|
||||||
)
|
)
|
||||||
ET.SubElement(
|
ET.SubElement(
|
||||||
@ -499,7 +519,6 @@ def _add_object_ball(
|
|||||||
name=f"objball_{i}_{j}_y",
|
name=f"objball_{i}_{j}_y",
|
||||||
axis="0 1 0",
|
axis="0 1 0",
|
||||||
pos="0 0 0",
|
pos="0 0 0",
|
||||||
range="-1 1",
|
|
||||||
type="slide",
|
type="slide",
|
||||||
)
|
)
|
||||||
ET.SubElement(
|
ET.SubElement(
|
||||||
@ -513,6 +532,30 @@ def _add_object_ball(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _add_objball_freejoint(
|
||||||
|
worldbody: ET.Element,
|
||||||
|
i: str,
|
||||||
|
j: str,
|
||||||
|
x: float,
|
||||||
|
y: float,
|
||||||
|
size: float,
|
||||||
|
) -> None:
|
||||||
|
body = ET.SubElement(worldbody, "body", name=f"objball_{i}_{j}", pos=f"{x} {y} 0")
|
||||||
|
ET.SubElement(
|
||||||
|
body,
|
||||||
|
"geom",
|
||||||
|
type="sphere",
|
||||||
|
name=f"objball_{i}_{j}_geom",
|
||||||
|
size=f"{size}", # Radius
|
||||||
|
pos=f"0.0 0.0 {size}", # Z = size so that this ball can move!!
|
||||||
|
rgba=maze_task.BLUE.rgba_str(),
|
||||||
|
contype="1",
|
||||||
|
conaffinity="1",
|
||||||
|
solimp="0.9 0.99 0.001",
|
||||||
|
)
|
||||||
|
ET.SubElement(body, "freejoint", name=f"objball_{i}_{j}_root")
|
||||||
|
|
||||||
|
|
||||||
def _add_movable_block(
|
def _add_movable_block(
|
||||||
worldbody: ET.Element,
|
worldbody: ET.Element,
|
||||||
struct: maze_env_utils.MazeCell,
|
struct: maze_env_utils.MazeCell,
|
||||||
|
@ -258,6 +258,7 @@ class DistRewardFall(GoalRewardFall, DistRewardMixIn):
|
|||||||
class GoalRewardMultiFall(GoalRewardUMaze):
|
class GoalRewardMultiFall(GoalRewardUMaze):
|
||||||
MAZE_SIZE_SCALING: Scaling = Scaling(ant=2.0, point=None, swimmer=None)
|
MAZE_SIZE_SCALING: Scaling = Scaling(ant=2.0, point=None, swimmer=None)
|
||||||
OBSERVE_BLOCKS: bool = True
|
OBSERVE_BLOCKS: bool = True
|
||||||
|
PENALTY: float = -0.0001
|
||||||
|
|
||||||
def __init__(self, scale: float, goal: Tuple[int, int] = (3.0, 1.0)) -> None:
|
def __init__(self, scale: float, goal: Tuple[int, int] = (3.0, 1.0)) -> None:
|
||||||
super().__init__(scale)
|
super().__init__(scale)
|
||||||
@ -672,12 +673,16 @@ class BanditBilliard(SubGoalBilliard):
|
|||||||
|
|
||||||
|
|
||||||
class GoalRewardSmallBilliard(GoalRewardBilliard):
|
class GoalRewardSmallBilliard(GoalRewardBilliard):
|
||||||
MAZE_SIZE_SCALING: Scaling = Scaling(ant=1.5, point=4.0, swimmer=None)
|
MAZE_SIZE_SCALING: Scaling = Scaling(ant=2.0, point=4.0, swimmer=None)
|
||||||
OBJECT_BALL_SIZE: float = 0.5
|
OBJECT_BALL_SIZE: float = 0.4
|
||||||
|
GOAL_SIZE: float = 0.2
|
||||||
|
|
||||||
def __init__(self, scale: float, goal: Tuple[float, float] = (-1.0, -2.0)) -> None:
|
def __init__(self, scale: float, goal: Tuple[float, float] = (-1.0, -2.0)) -> None:
|
||||||
super().__init__(scale, goal)
|
super().__init__(scale, goal)
|
||||||
|
|
||||||
|
def _threshold(self) -> float:
|
||||||
|
return self.OBJECT_BALL_SIZE + self.GOAL_SIZE
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_maze() -> List[List[MazeCell]]:
|
def create_maze() -> List[List[MazeCell]]:
|
||||||
E, B = MazeCell.EMPTY, MazeCell.BLOCK
|
E, B = MazeCell.EMPTY, MazeCell.BLOCK
|
||||||
|
@ -19,6 +19,7 @@ class PointEnv(AgentModel):
|
|||||||
ORI_IND: int = 2
|
ORI_IND: int = 2
|
||||||
MANUAL_COLLISION: bool = True
|
MANUAL_COLLISION: bool = True
|
||||||
RADIUS: float = 0.4
|
RADIUS: float = 0.4
|
||||||
|
OBJBALL_TYPE: str = "hinge"
|
||||||
|
|
||||||
VELOCITY_LIMITS: float = 10.0
|
VELOCITY_LIMITS: float = 10.0
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user