Modify collision detection
This commit is contained in:
		
							parent
							
								
									348e975e60
								
							
						
					
					
						commit
						7431980838
					
				| @ -19,6 +19,13 @@ for maze_id in MAZE_IDS: | |||||||
|         max_episode_steps=1000, |         max_episode_steps=1000, | ||||||
|         reward_threshold=-1000, |         reward_threshold=-1000, | ||||||
|     ) |     ) | ||||||
|  |     gym.envs.register( | ||||||
|  |         id="Ant{}-v1".format(maze_id), | ||||||
|  |         entry_point="mujoco_maze.ant_maze_env:AntMazeEnv", | ||||||
|  |         kwargs=dict(maze_size_scaling=8.0, **_get_kwargs(maze_id)), | ||||||
|  |         max_episode_steps=1000, | ||||||
|  |         reward_threshold=0.9, | ||||||
|  |     ) | ||||||
| 
 | 
 | ||||||
| for maze_id in MAZE_IDS: | for maze_id in MAZE_IDS: | ||||||
|     gym.envs.register( |     gym.envs.register( | ||||||
| @ -28,6 +35,13 @@ for maze_id in MAZE_IDS: | |||||||
|         max_episode_steps=1000, |         max_episode_steps=1000, | ||||||
|         reward_threshold=-1000, |         reward_threshold=-1000, | ||||||
|     ) |     ) | ||||||
|  |     gym.envs.register( | ||||||
|  |         id="Point{}-v1".format(maze_id), | ||||||
|  |         entry_point="mujoco_maze.point_maze_env:PointMazeEnv", | ||||||
|  |         kwargs=dict(**_get_kwargs(maze_id), dense_reward=False), | ||||||
|  |         max_episode_steps=1000, | ||||||
|  |         reward_threshold=0.9 | ||||||
|  |     ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| __version__ = "0.1.0" | __version__ = "0.1.0" | ||||||
|  | |||||||
| @ -4,7 +4,6 @@ from abc import ABC, abstractmethod | |||||||
| from gym.envs.mujoco.mujoco_env import MujocoEnv | from gym.envs.mujoco.mujoco_env import MujocoEnv | ||||||
| from gym.utils import EzPickle | from gym.utils import EzPickle | ||||||
| import numpy as np | import numpy as np | ||||||
| from typing import Tuple |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class AgentModel(ABC, MujocoEnv, EzPickle): | class AgentModel(ABC, MujocoEnv, EzPickle): | ||||||
| @ -22,13 +21,13 @@ class AgentModel(ABC, MujocoEnv, EzPickle): | |||||||
|         pass |         pass | ||||||
| 
 | 
 | ||||||
|     @abstractmethod |     @abstractmethod | ||||||
|     def get_xy(self) -> Tuple[float, float]: |     def get_xy(self) -> np.ndarray: | ||||||
|         """Returns the coordinate of the agent. |         """Returns the coordinate of the agent. | ||||||
|         """ |         """ | ||||||
|         pass |         pass | ||||||
| 
 | 
 | ||||||
|     @abstractmethod |     @abstractmethod | ||||||
|     def set_xy(self, xy: Tuple[float, float]) -> None: |     def set_xy(self, xy: np.ndarray) -> None: | ||||||
|         """Set the coordinate of the agent. |         """Set the coordinate of the agent. | ||||||
|         """ |         """ | ||||||
|         pass |         pass | ||||||
|  | |||||||
| @ -3,29 +3,30 @@ | |||||||
|     <option timestep="0.02" integrator="RK4" /> |     <option timestep="0.02" integrator="RK4" /> | ||||||
|     <default> |     <default> | ||||||
|         <joint limited="false" armature="0" damping="0" /> |         <joint limited="false" armature="0" damping="0" /> | ||||||
|         <geom condim="3" conaffinity="0" margin="0" friction="1 0.5 0.5" rgba="0.8 0.6 0.4 1" density="100" /> |         <geom condim="3" conaffinity="0" margin="0" friction="1.0 0.5 0.5" | ||||||
|  |               solimp="0.99 0.999 0.001" rgba="0.8 0.6 0.4 1" density="100" /> | ||||||
|     </default> |     </default> | ||||||
|     <asset> |     <asset> | ||||||
|         <texture type="skybox" builtin="gradient" width="100" height="100" rgb1="1 1 1" rgb2="0 0 0" /> |         <texture type="skybox" builtin="gradient" width="100" height="100" rgb1="1 1 1" rgb2="0 0 0" /> | ||||||
|         <texture name="texgeom" type="cube" builtin="flat" mark="cross" width="127" height="1278" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" markrgb="1 1 1" random="0.01" /> |         <texture name="texgeom" type="cube" builtin="flat" mark="cross" width="127" height="1278" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" markrgb="1 1 1" random="0.01" /> | ||||||
|         <texture name="texplane" type="2d" builtin="checker" rgb1="0 0 0" rgb2="0.8 0.8 0.8" width="100" height="100" /> |         <texture name="texplane" type="2d" builtin="checker" rgb1="0 0 0" rgb2="0.8 0.8 0.8" width="100" height="100" /> | ||||||
|         <material name='MatPlane' texture="texplane" shininess="1" texrepeat="30 30" specular="1"  reflectance="0.5" /> |         <material name="MatPlane" texture="texplane" shininess="1" texrepeat="30 30" specular="1"  reflectance="0.5" /> | ||||||
|         <material name='geom' texture="texgeom" texuniform="true" /> |         <material name="geom" texture="texgeom" texuniform="true" /> | ||||||
|     </asset> |     </asset> | ||||||
|     <worldbody> |     <worldbody> | ||||||
|         <light directional="true" cutoff="100" exponent="1" diffuse="1 1 1" specular=".1 .1 .1" pos="0 0 1.3" dir="-0 0 -1.3" /> |         <light directional="true" cutoff="100" exponent="1" diffuse="1 1 1" specular=".1 .1 .1" pos="0 0 1.3" dir="-0 0 -1.3" /> | ||||||
|         <geom name='floor' material="MatPlane" pos='0 0 0' size='40 40 40' type='plane' conaffinity='1' rgba='0.8 0.9 0.8 1' condim='3' /> |         <geom name="floor" material="MatPlane" pos="0 0 0" size="40 40 40" type="plane" conaffinity="1" rgba="0.8 0.9 0.8 1" condim="3" /> | ||||||
|         <body name="torso" pos="0 0 0"> |         <body name="torso" pos="0 0 0"> | ||||||
|             <geom name="pointbody" type="sphere" size="0.5" pos="0 0 0.5" solimp="0.98 0.99 0.001" /> |             <geom name="pointbody" type="sphere" size="0.5" pos="0 0 0.5" /> | ||||||
|             <geom name="pointarrow" type="box" size="0.5 0.1 0.1" pos="0.6 0 0.5" solimp="0.98 0.99 0.001" /> |             <geom name="pointarrow" type="box" size="0.5 0.1 0.1" pos="0.6 0 0.5" /> | ||||||
|             <joint name='ballx' type='slide' axis='1 0 0' pos='0 0 0' /> |             <joint name="ballx" type="slide" axis="1 0 0" pos="0 0 0" /> | ||||||
|             <joint name='bally' type='slide' axis='0 1 0' pos='0 0 0' /> |             <joint name="bally" type="slide" axis="0 1 0" pos="0 0 0" /> | ||||||
|             <joint name='rot' type='hinge' axis='0 0 1' pos='0 0 0' limited="false" /> |             <joint name="rot" type="hinge" axis="0 0 1" pos="0 0 0" limited="false" /> | ||||||
|         </body> |         </body> | ||||||
|     </worldbody> |     </worldbody> | ||||||
|     <actuator> |     <actuator> | ||||||
|         <!-- Those are just dummy actuators for providing ranges --> |         <!-- Those are just dummy actuators for providing ranges --> | ||||||
|         <motor joint='ballx' ctrlrange="-1 1" ctrllimited="true" /> |         <motor joint="ballx" ctrlrange="-1 1" ctrllimited="true" /> | ||||||
|         <motor joint='rot' ctrlrange="-0.25 0.25" ctrllimited="true" /> |         <motor joint="rot" ctrlrange="-0.25 0.25" ctrllimited="true" /> | ||||||
|     </actuator> |     </actuator> | ||||||
| </mujoco> | </mujoco> | ||||||
|  | |||||||
| @ -36,6 +36,8 @@ class MazeEnv(gym.Env): | |||||||
|     MODEL_CLASS: Type[AgentModel] = AgentModel |     MODEL_CLASS: Type[AgentModel] = AgentModel | ||||||
| 
 | 
 | ||||||
|     MANUAL_COLLISION: bool = False |     MANUAL_COLLISION: bool = False | ||||||
|  |     # For preventing the point from going through the wall | ||||||
|  |     SIZE_EPS = 0.0001 | ||||||
| 
 | 
 | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
| @ -52,7 +54,7 @@ class MazeEnv(gym.Env): | |||||||
|         goal_sampler: Union[str, np.ndarray, Callable[[], np.ndarray]] = "default", |         goal_sampler: Union[str, np.ndarray, Callable[[], np.ndarray]] = "default", | ||||||
|         *args, |         *args, | ||||||
|         **kwargs, |         **kwargs, | ||||||
|     ): |     ) -> None: | ||||||
|         self._maze_id = maze_id |         self._maze_id = maze_id | ||||||
| 
 | 
 | ||||||
|         xml_path = os.path.join(MODEL_DIR, self.MODEL_CLASS.FILE) |         xml_path = os.path.join(MODEL_DIR, self.MODEL_CLASS.FILE) | ||||||
| @ -117,85 +119,67 @@ class MazeEnv(gym.Env): | |||||||
|                     struct = maze_env_utils.Move.SpinXY |                     struct = maze_env_utils.Move.SpinXY | ||||||
|                 if self.elevated and struct not in [-1]: |                 if self.elevated and struct not in [-1]: | ||||||
|                     # Create elevated platform. |                     # Create elevated platform. | ||||||
|  |                     x = j * size_scaling - torso_x | ||||||
|  |                     y = i * size_scaling - torso_y | ||||||
|  |                     h = height / 2 * size_scaling | ||||||
|  |                     size = 0.5 * size_scaling + self.SIZE_EPS | ||||||
|                     ET.SubElement( |                     ET.SubElement( | ||||||
|                         worldbody, |                         worldbody, | ||||||
|                         "geom", |                         "geom", | ||||||
|                         name="elevated_%d_%d" % (i, j), |                         name=f"elevated_{i}_{j}", | ||||||
|                         pos="%f %f %f" |                         pos=f"{x} {y} {h}", | ||||||
|                         % ( |                         size=f"{size} {size} {h}", | ||||||
|                             j * size_scaling - torso_x, |  | ||||||
|                             i * size_scaling - torso_y, |  | ||||||
|                             height / 2 * size_scaling, |  | ||||||
|                         ), |  | ||||||
|                         size="%f %f %f" |  | ||||||
|                         % ( |  | ||||||
|                             0.5 * size_scaling, |  | ||||||
|                             0.5 * size_scaling, |  | ||||||
|                             height / 2 * size_scaling, |  | ||||||
|                         ), |  | ||||||
|                         type="box", |                         type="box", | ||||||
|                         material="", |                         material="", | ||||||
|                         contype="1", |                         contype="1", | ||||||
|                         conaffinity="1", |                         conaffinity="1", | ||||||
|                         rgba="0.9 0.9 0.9 1", |                         rgba="0.9 0.9 0.9 1", | ||||||
|                     ) |                     ) | ||||||
|                 if struct == 1:  # Unmovable block. |                 if struct == 1: | ||||||
|  |                     # Unmovable block. | ||||||
|                     # Offset all coordinates so that robot starts at the origin. |                     # Offset all coordinates so that robot starts at the origin. | ||||||
|  |                     x = j * size_scaling - torso_x | ||||||
|  |                     y = i * size_scaling - torso_y | ||||||
|  |                     h = height / 2 * size_scaling | ||||||
|  |                     size = 0.5 * size_scaling + self.SIZE_EPS | ||||||
|                     ET.SubElement( |                     ET.SubElement( | ||||||
|                         worldbody, |                         worldbody, | ||||||
|                         "geom", |                         "geom", | ||||||
|                         name="block_%d_%d" % (i, j), |                         name=f"block_{i}_{j}", | ||||||
|                         pos="%f %f %f" |                         pos=f"{x} {y} {h + height_offset}", | ||||||
|                         % ( |                         size=f"{size} {size} {h}", | ||||||
|                             j * size_scaling - torso_x, |  | ||||||
|                             i * size_scaling - torso_y, |  | ||||||
|                             height_offset + height / 2 * size_scaling, |  | ||||||
|                         ), |  | ||||||
|                         size="%f %f %f" |  | ||||||
|                         % ( |  | ||||||
|                             0.5 * size_scaling, |  | ||||||
|                             0.5 * size_scaling, |  | ||||||
|                             height / 2 * size_scaling, |  | ||||||
|                         ), |  | ||||||
|                         type="box", |                         type="box", | ||||||
|                         material="", |                         material="", | ||||||
|                         contype="1", |                         contype="1", | ||||||
|                         conaffinity="1", |                         conaffinity="1", | ||||||
|                         rgba="0.4 0.4 0.4 1", |                         rgba="0.4 0.4 0.4 1", | ||||||
|                     ) |                     ) | ||||||
|                 elif maze_env_utils.can_move(struct):  # Movable block. |                 elif maze_env_utils.can_move(struct): | ||||||
|  |                     # Movable block. | ||||||
|                     # The "falling" blocks are shrunk slightly and increased in mass to |                     # The "falling" blocks are shrunk slightly and increased in mass to | ||||||
|                     # ensure it can fall easily through a gap in the platform blocks. |                     # ensure it can fall easily through a gap in the platform blocks. | ||||||
|                     name = "movable_%d_%d" % (i, j) |                     name = "movable_%d_%d" % (i, j) | ||||||
|                     self.movable_blocks.append((name, struct)) |                     self.movable_blocks.append((name, struct)) | ||||||
|                     falling = maze_env_utils.can_move_z(struct) |                     falling = maze_env_utils.can_move_z(struct) | ||||||
|                     spinning = maze_env_utils.can_spin(struct) |                     spinning = maze_env_utils.can_spin(struct) | ||||||
|                     x_offset = 0.25 * size_scaling if spinning else 0.0 |  | ||||||
|                     y_offset = 0.0 |  | ||||||
|                     shrink = 0.1 if spinning else 0.99 if falling else 1.0 |                     shrink = 0.1 if spinning else 0.99 if falling else 1.0 | ||||||
|                     height_shrink = 0.1 if spinning else 1.0 |                     height_shrink = 0.1 if spinning else 1.0 | ||||||
|  |                     x = j * size_scaling - torso_x + 0.25 * size_scaling if spinning else 0.0 | ||||||
|  |                     y = i * size_scaling - torso_y | ||||||
|  |                     h = height / 2 * size_scaling * height_shrink | ||||||
|  |                     size = 0.5 * size_scaling * shrink + self.SIZE_EPS | ||||||
|                     movable_body = ET.SubElement( |                     movable_body = ET.SubElement( | ||||||
|                         worldbody, |                         worldbody, | ||||||
|                         "body", |                         "body", | ||||||
|                         name=name, |                         name=name, | ||||||
|                         pos="%f %f %f" |                         pos=f"{x} {y} {height_offset + h}", | ||||||
|                         % ( |  | ||||||
|                             j * size_scaling - torso_x + x_offset, |  | ||||||
|                             i * size_scaling - torso_y + y_offset, |  | ||||||
|                             height_offset + height / 2 * size_scaling * height_shrink, |  | ||||||
|                         ), |  | ||||||
|                     ) |                     ) | ||||||
|                     ET.SubElement( |                     ET.SubElement( | ||||||
|                         movable_body, |                         movable_body, | ||||||
|                         "geom", |                         "geom", | ||||||
|                         name="block_%d_%d" % (i, j), |                         name=f"block_{i}_{j}", | ||||||
|                         pos="0 0 0", |                         pos="0 0 0", | ||||||
|                         size="%f %f %f" |                         size=f"{size} {size} {h}", | ||||||
|                         % ( |  | ||||||
|                             0.5 * size_scaling * shrink, |  | ||||||
|                             0.5 * size_scaling * shrink, |  | ||||||
|                             height / 2 * size_scaling * height_shrink, |  | ||||||
|                         ), |  | ||||||
|                         type="box", |                         type="box", | ||||||
|                         material="", |                         material="", | ||||||
|                         mass="0.001" if falling else "0.0002", |                         mass="0.001" if falling else "0.0002", | ||||||
| @ -211,9 +195,9 @@ class MazeEnv(gym.Env): | |||||||
|                             axis="1 0 0", |                             axis="1 0 0", | ||||||
|                             damping="0.0", |                             damping="0.0", | ||||||
|                             limited="true" if falling else "false", |                             limited="true" if falling else "false", | ||||||
|                             range="%f %f" % (-size_scaling, size_scaling), |                             range=f"{-size_scaling} {size_scaling}", | ||||||
|                             margin="0.01", |                             margin="0.01", | ||||||
|                             name="movable_x_%d_%d" % (i, j), |                             name=f"movable_x_{i}_{j}", | ||||||
|                             pos="0 0 0", |                             pos="0 0 0", | ||||||
|                             type="slide", |                             type="slide", | ||||||
|                         ) |                         ) | ||||||
| @ -225,9 +209,9 @@ class MazeEnv(gym.Env): | |||||||
|                             axis="0 1 0", |                             axis="0 1 0", | ||||||
|                             damping="0.0", |                             damping="0.0", | ||||||
|                             limited="true" if falling else "false", |                             limited="true" if falling else "false", | ||||||
|                             range="%f %f" % (-size_scaling, size_scaling), |                             range=f"{-size_scaling} {size_scaling}", | ||||||
|                             margin="0.01", |                             margin="0.01", | ||||||
|                             name="movable_y_%d_%d" % (i, j), |                             name=f"movable_y_{i}_{j}", | ||||||
|                             pos="0 0 0", |                             pos="0 0 0", | ||||||
|                             type="slide", |                             type="slide", | ||||||
|                         ) |                         ) | ||||||
| @ -239,9 +223,9 @@ class MazeEnv(gym.Env): | |||||||
|                             axis="0 0 1", |                             axis="0 0 1", | ||||||
|                             damping="0.0", |                             damping="0.0", | ||||||
|                             limited="true", |                             limited="true", | ||||||
|                             range="%f 0" % (-height_offset), |                             range=f"{-height_offset} 0", | ||||||
|                             margin="0.01", |                             margin="0.01", | ||||||
|                             name="movable_z_%d_%d" % (i, j), |                             name=f"movable_z_{i}_{j}", | ||||||
|                             pos="0 0 0", |                             pos="0 0 0", | ||||||
|                             type="slide", |                             type="slide", | ||||||
|                         ) |                         ) | ||||||
| @ -253,7 +237,7 @@ class MazeEnv(gym.Env): | |||||||
|                             axis="0 0 1", |                             axis="0 0 1", | ||||||
|                             damping="0.0", |                             damping="0.0", | ||||||
|                             limited="false", |                             limited="false", | ||||||
|                             name="spinable_%d_%d" % (i, j), |                             name=f"spinable_{i}_{j}", | ||||||
|                             pos="0 0 0", |                             pos="0 0 0", | ||||||
|                             type="ball", |                             type="ball", | ||||||
|                         ) |                         ) | ||||||
| @ -545,7 +529,8 @@ class MazeEnv(gym.Env): | |||||||
|         if self.MANUAL_COLLISION: |         if self.MANUAL_COLLISION: | ||||||
|             old_pos = self.wrapped_env.get_xy() |             old_pos = self.wrapped_env.get_xy() | ||||||
|             inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action) |             inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action) | ||||||
|             if self._collision.is_in(self.wrapped_env.get_xy()): |             new_pos = self.wrapped_env.get_xy() | ||||||
|  |             if self._collision.is_in(old_pos, new_pos): | ||||||
|                 self.wrapped_env.set_xy(old_pos) |                 self.wrapped_env.set_xy(old_pos) | ||||||
|         else: |         else: | ||||||
|             inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action) |             inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action) | ||||||
| @ -574,9 +559,17 @@ def _reward_fn(maze_id: str, dense: str) -> callable: | |||||||
|             raise NotImplementedError(f"Unknown maze id: {maze_id}") |             raise NotImplementedError(f"Unknown maze id: {maze_id}") | ||||||
|     else: |     else: | ||||||
|         if maze_id in ["Maze", "Push", "BlockMaze"]: |         if maze_id in ["Maze", "Push", "BlockMaze"]: | ||||||
|             return lambda obs, goal: (np.linalg.norm(obs[:2] - goal) <= 0.6) * 1.0 |             return ( | ||||||
|  |                 lambda obs, goal: -0.001 | ||||||
|  |                 if np.linalg.norm(obs[:2] - goal) <= 0.6 | ||||||
|  |                 else 1.0 | ||||||
|  |             ) | ||||||
|         elif maze_id == "Fall": |         elif maze_id == "Fall": | ||||||
|             return lambda obs, goal: (np.linalg.norm(obs[:3] - goal) <= 0.6) * 1.0 |             return ( | ||||||
|  |                 lambda obs, goal: -0.001 | ||||||
|  |                 if np.linalg.norm(obs[:3] - goal) <= 0.6 | ||||||
|  |                 else 1.0 | ||||||
|  |             ) | ||||||
|         else: |         else: | ||||||
|             raise NotImplementedError(f"Unknown maze id: {maze_id}") |             raise NotImplementedError(f"Unknown maze id: {maze_id}") | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -104,7 +104,7 @@ class Collision: | |||||||
|     """ |     """ | ||||||
| 
 | 
 | ||||||
|     ARROUND = np.array([[-1, 0], [1, 0], [0, -1], [0, 1]]) |     ARROUND = np.array([[-1, 0], [1, 0], [0, -1], [0, 1]]) | ||||||
|     OFFSET = {False: 0.45, True: 0.5} |     OFFSET = {False: 0.499, True: 0.501} | ||||||
| 
 | 
 | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, structure: list, size_scaling: float, torso_x: float, torso_y: float, |         self, structure: list, size_scaling: float, torso_x: float, torso_y: float, | ||||||
| @ -134,11 +134,11 @@ class Collision: | |||||||
|             max_x = x_base + size_scaling * offset(pos, 3) |             max_x = x_base + size_scaling * offset(pos, 3) | ||||||
|             self.objects.append((min_y, max_y, min_x, max_x)) |             self.objects.append((min_y, max_y, min_x, max_x)) | ||||||
| 
 | 
 | ||||||
|     def is_in(self, pos) -> bool: |     def is_in(self, old_pos, new_pos) -> bool: | ||||||
|         x, y = pos |         for x, y in (new_pos, (old_pos + new_pos) / 2): | ||||||
|         for min_y, max_y, min_x, max_x in self.objects: |             for min_y, max_y, min_x, max_x in self.objects: | ||||||
|             if min_x <= x <= max_x and min_y <= y <= max_y: |                 if min_x <= x <= max_x and min_y <= y <= max_y: | ||||||
|                 return True |                     return True | ||||||
|         return False |         return False | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -78,8 +78,7 @@ class PointEnv(AgentModel): | |||||||
|         return self._get_obs() |         return self._get_obs() | ||||||
| 
 | 
 | ||||||
|     def get_xy(self): |     def get_xy(self): | ||||||
|         qpos = self.sim.data.qpos |         return self.sim.data.qpos[:2] | ||||||
|         return qpos[0], qpos[1] |  | ||||||
| 
 | 
 | ||||||
|     def set_xy(self, xy): |     def set_xy(self, xy): | ||||||
|         qpos = np.copy(self.sim.data.qpos) |         qpos = np.copy(self.sim.data.qpos) | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user