AntBilliard

This commit is contained in:
Yuji Kanagawa 2021-10-03 11:24:15 +09:00
parent 30c845fb0c
commit cd55df02b1
5 changed files with 58 additions and 7 deletions

View File

@ -13,6 +13,7 @@ class AgentModel(ABC, MujocoEnv, EzPickle):
MANUAL_COLLISION: bool
ORI_IND: Optional[int] = None
RADIUS: Optional[float] = None
OBJBALL_TYPE: Optional[str] = None
def __init__(self, file_path: str, frame_skip: int) -> None:
MujocoEnv.__init__(self, file_path, frame_skip)

View File

@ -39,6 +39,7 @@ class AntEnv(AgentModel):
FILE: str = "ant.xml"
ORI_IND: int = 3
MANUAL_COLLISION: bool = False
OBJBALL_TYPE: str = "freejoint"
def __init__(
self,

View File

@ -165,7 +165,28 @@ class MazeEnv(gym.Env):
elif struct.is_object_ball():
# Movable Ball
self.object_balls.append(f"objball_{i}_{j}")
_add_object_ball(worldbody, i, j, x, y, self._task.OBJECT_BALL_SIZE)
if model_cls.OBJBALL_TYPE == "hinge":
_add_objball_hinge(
worldbody,
i,
j,
x,
y,
self._task.OBJECT_BALL_SIZE,
)
elif model_cls.OBJBALL_TYPE == "freejoint":
_add_objball_freejoint(
worldbody,
i,
j,
x,
y,
self._task.OBJECT_BALL_SIZE,
)
else:
raise ValueError(
f"OBJBALL_TYPE is not registered for {model_cls}"
)
torso = tree.find(".//body[@name='torso']")
geoms = torso.findall(".//geom")
@ -185,7 +206,7 @@ class MazeEnv(gym.Env):
"site",
name=f"goal_site{i}",
pos=f"{goal.pos[0]} {goal.pos[1]} {z}",
size=f"{maze_size_scaling * 0.1}",
size=size,
rgba=goal.rgb.rgba_str(),
)
@ -461,7 +482,7 @@ class MazeEnv(gym.Env):
self._websock_server_pipe.send(None)
def _add_object_ball(
def _add_objball_hinge(
worldbody: ET.Element,
i: str,
j: str,
@ -490,7 +511,6 @@ def _add_object_ball(
name=f"objball_{i}_{j}_x",
axis="1 0 0",
pos="0 0 0",
range="-1 1",
type="slide",
)
ET.SubElement(
@ -499,7 +519,6 @@ def _add_object_ball(
name=f"objball_{i}_{j}_y",
axis="0 1 0",
pos="0 0 0",
range="-1 1",
type="slide",
)
ET.SubElement(
@ -513,6 +532,30 @@ def _add_object_ball(
)
def _add_objball_freejoint(
worldbody: ET.Element,
i: str,
j: str,
x: float,
y: float,
size: float,
) -> None:
body = ET.SubElement(worldbody, "body", name=f"objball_{i}_{j}", pos=f"{x} {y} 0")
ET.SubElement(
body,
"geom",
type="sphere",
name=f"objball_{i}_{j}_geom",
size=f"{size}", # Radius
pos=f"0.0 0.0 {size}", # Z = size so that this ball can move!!
rgba=maze_task.BLUE.rgba_str(),
contype="1",
conaffinity="1",
solimp="0.9 0.99 0.001",
)
ET.SubElement(body, "freejoint", name=f"objball_{i}_{j}_root")
def _add_movable_block(
worldbody: ET.Element,
struct: maze_env_utils.MazeCell,

View File

@ -258,6 +258,7 @@ class DistRewardFall(GoalRewardFall, DistRewardMixIn):
class GoalRewardMultiFall(GoalRewardUMaze):
MAZE_SIZE_SCALING: Scaling = Scaling(ant=2.0, point=None, swimmer=None)
OBSERVE_BLOCKS: bool = True
PENALTY: float = -0.0001
def __init__(self, scale: float, goal: Tuple[int, int] = (3.0, 1.0)) -> None:
super().__init__(scale)
@ -672,12 +673,16 @@ class BanditBilliard(SubGoalBilliard):
class GoalRewardSmallBilliard(GoalRewardBilliard):
MAZE_SIZE_SCALING: Scaling = Scaling(ant=1.5, point=4.0, swimmer=None)
OBJECT_BALL_SIZE: float = 0.5
MAZE_SIZE_SCALING: Scaling = Scaling(ant=2.0, point=4.0, swimmer=None)
OBJECT_BALL_SIZE: float = 0.4
GOAL_SIZE: float = 0.2
def __init__(self, scale: float, goal: Tuple[float, float] = (-1.0, -2.0)) -> None:
super().__init__(scale, goal)
def _threshold(self) -> float:
return self.OBJECT_BALL_SIZE + self.GOAL_SIZE
@staticmethod
def create_maze() -> List[List[MazeCell]]:
E, B = MazeCell.EMPTY, MazeCell.BLOCK

View File

@ -19,6 +19,7 @@ class PointEnv(AgentModel):
ORI_IND: int = 2
MANUAL_COLLISION: bool = True
RADIUS: float = 0.4
OBJBALL_TYPE: str = "hinge"
VELOCITY_LIMITS: float = 10.0