Ant Billiard
This commit is contained in:
		
							parent
							
								
									06cf7c9b8b
								
							
						
					
					
						commit
						98fe8b977e
					
				| @ -75,7 +75,6 @@ class MazeEnv(gym.Env): | |||||||
|                 torso_y, |                 torso_y, | ||||||
|                 model_cls.RADIUS, |                 model_cls.RADIUS, | ||||||
|             ) |             ) | ||||||
|             # Now all object balls have size=1.0 |  | ||||||
|             self._objball_collision = maze_env_utils.CollisionDetector( |             self._objball_collision = maze_env_utils.CollisionDetector( | ||||||
|                 structure, |                 structure, | ||||||
|                 size_scaling, |                 size_scaling, | ||||||
| @ -490,7 +489,8 @@ def _add_object_ball( | |||||||
|         "joint", |         "joint", | ||||||
|         name=f"objball_{i}_{j}_x", |         name=f"objball_{i}_{j}_x", | ||||||
|         axis="1 0 0", |         axis="1 0 0", | ||||||
|         pos="0 0 0.0", |         pos="0 0 0", | ||||||
|  |         range="-1 1", | ||||||
|         type="slide", |         type="slide", | ||||||
|     ) |     ) | ||||||
|     ET.SubElement( |     ET.SubElement( | ||||||
| @ -499,6 +499,7 @@ def _add_object_ball( | |||||||
|         name=f"objball_{i}_{j}_y", |         name=f"objball_{i}_{j}_y", | ||||||
|         axis="0 1 0", |         axis="0 1 0", | ||||||
|         pos="0 0 0", |         pos="0 0 0", | ||||||
|  |         range="-1 1", | ||||||
|         type="slide", |         type="slide", | ||||||
|     ) |     ) | ||||||
|     ET.SubElement( |     ET.SubElement( | ||||||
|  | |||||||
| @ -256,7 +256,7 @@ class DistRewardFall(GoalRewardFall, DistRewardMixIn): | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class GoalRewardMultiFall(GoalRewardUMaze): | class GoalRewardMultiFall(GoalRewardUMaze): | ||||||
|     MAZE_SIZE_SCALING: Scaling = Scaling(ant=2.0, point=6.0, swimmer=None) |     MAZE_SIZE_SCALING: Scaling = Scaling(ant=2.0, point=None, swimmer=None) | ||||||
|     OBSERVE_BLOCKS: bool = True |     OBSERVE_BLOCKS: bool = True | ||||||
| 
 | 
 | ||||||
|     def __init__(self, scale: float, goal: Tuple[int, int] = (0.0, 4.0)) -> None: |     def __init__(self, scale: float, goal: Tuple[int, int] = (0.0, 4.0)) -> None: | ||||||
| @ -674,6 +674,36 @@ class BanditBilliard(SubGoalBilliard): | |||||||
|         ] |         ] | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | class GoalRewardSmallBilliard(GoalRewardBilliard): | ||||||
|  |     MAZE_SIZE_SCALING: Scaling = Scaling(ant=1.5, point=4.0, swimmer=None) | ||||||
|  |     OBJECT_BALL_SIZE: float = 0.5 | ||||||
|  | 
 | ||||||
|  |     def __init__(self, scale: float, goal: Tuple[float, float] = (-21.0, -2.0)) -> None: | ||||||
|  |         super().__init__(scale, goal) | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     def create_maze() -> List[List[MazeCell]]: | ||||||
|  |         E, B = MazeCell.EMPTY, MazeCell.BLOCK | ||||||
|  |         R, M = MazeCell.ROBOT, MazeCell.OBJECT_BALL | ||||||
|  |         return [ | ||||||
|  |             [B, B, B, B, B], | ||||||
|  |             [B, E, E, E, B], | ||||||
|  |             [B, E, M, E, B], | ||||||
|  |             [B, E, R, E, B], | ||||||
|  |             [B, E, E, E, B], | ||||||
|  |             [B, B, B, B, B], | ||||||
|  |         ] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class DistRewardSmallBilliard(GoalRewardSmallBilliard, DistRewardMixIn): | ||||||
|  |     pass | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class NoRewardSmallBilliard(GoalRewardSmallBilliard): | ||||||
|  |     def reward(self, _obs: np.ndarray) -> float: | ||||||
|  |         return 0.0 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class TaskRegistry: | class TaskRegistry: | ||||||
|     REGISTRY: Dict[str, List[Type[MazeTask]]] = { |     REGISTRY: Dict[str, List[Type[MazeTask]]] = { | ||||||
|         "SimpleRoom": [DistRewardSimpleRoom, GoalRewardSimpleRoom], |         "SimpleRoom": [DistRewardSimpleRoom, GoalRewardSimpleRoom], | ||||||
| @ -697,6 +727,11 @@ class TaskRegistry: | |||||||
|             BanditBilliard,  # v3 |             BanditBilliard,  # v3 | ||||||
|             NoRewardBilliard,  # v4 |             NoRewardBilliard,  # v4 | ||||||
|         ], |         ], | ||||||
|  |         "SmallBilliard": [ | ||||||
|  |             DistRewardSmallBilliard, | ||||||
|  |             GoalRewardSmallBilliard, | ||||||
|  |             NoRewardSmallBilliard, | ||||||
|  |         ], | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @staticmethod |     @staticmethod | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user