Ant Billiard
This commit is contained in:
		
							parent
							
								
									06cf7c9b8b
								
							
						
					
					
						commit
						98fe8b977e
					
				| @ -75,7 +75,6 @@ class MazeEnv(gym.Env): | ||||
|                 torso_y, | ||||
|                 model_cls.RADIUS, | ||||
|             ) | ||||
|             # Now all object balls have size=1.0 | ||||
|             self._objball_collision = maze_env_utils.CollisionDetector( | ||||
|                 structure, | ||||
|                 size_scaling, | ||||
| @ -490,7 +489,8 @@ def _add_object_ball( | ||||
|         "joint", | ||||
|         name=f"objball_{i}_{j}_x", | ||||
|         axis="1 0 0", | ||||
|         pos="0 0 0.0", | ||||
|         pos="0 0 0", | ||||
|         range="-1 1", | ||||
|         type="slide", | ||||
|     ) | ||||
|     ET.SubElement( | ||||
| @ -499,6 +499,7 @@ def _add_object_ball( | ||||
|         name=f"objball_{i}_{j}_y", | ||||
|         axis="0 1 0", | ||||
|         pos="0 0 0", | ||||
|         range="-1 1", | ||||
|         type="slide", | ||||
|     ) | ||||
|     ET.SubElement( | ||||
|  | ||||
| @ -256,7 +256,7 @@ class DistRewardFall(GoalRewardFall, DistRewardMixIn): | ||||
| 
 | ||||
| 
 | ||||
| class GoalRewardMultiFall(GoalRewardUMaze): | ||||
|     MAZE_SIZE_SCALING: Scaling = Scaling(ant=2.0, point=6.0, swimmer=None) | ||||
|     MAZE_SIZE_SCALING: Scaling = Scaling(ant=2.0, point=None, swimmer=None) | ||||
|     OBSERVE_BLOCKS: bool = True | ||||
| 
 | ||||
|     def __init__(self, scale: float, goal: Tuple[int, int] = (0.0, 4.0)) -> None: | ||||
| @ -674,6 +674,36 @@ class BanditBilliard(SubGoalBilliard): | ||||
|         ] | ||||
| 
 | ||||
| 
 | ||||
class GoalRewardSmallBilliard(GoalRewardBilliard):
    """Billiard goal-reward task played in a small walled room.

    Sets its own ``MAZE_SIZE_SCALING`` and ``OBJECT_BALL_SIZE`` and provides
    a 6-row by 5-column layout containing one object ball and the robot.
    """

    MAZE_SIZE_SCALING: Scaling = Scaling(ant=1.5, point=4.0, swimmer=None)
    OBJECT_BALL_SIZE: float = 0.5

    # NOTE(review): the default goal (-21.0, -2.0) looks far from a 6x5 grid;
    # the coordinate convention is not visible here -- confirm it is intended.
    def __init__(self, scale: float, goal: Tuple[float, float] = (-21.0, -2.0)) -> None:
        """Pass ``scale`` and ``goal`` straight through to the parent task."""
        super().__init__(scale, goal)

    @staticmethod
    def create_maze() -> List[List[MazeCell]]:
        """Return the room as a list of rows of ``MazeCell`` values.

        The outer ring is all blocks; the interior is empty except for the
        object ball in row 2 and the robot in row 3, both in the middle column.
        """
        wall = MazeCell.BLOCK
        free = MazeCell.EMPTY
        ball = MazeCell.OBJECT_BALL
        robot = MazeCell.ROBOT

        def interior_row(middle: MazeCell) -> List[MazeCell]:
            # A fresh list per call so no two rows share the same object.
            return [wall, free, middle, free, wall]

        return [
            [wall] * 5,
            interior_row(free),
            interior_row(ball),
            interior_row(robot),
            interior_row(free),
            [wall] * 5,
        ]
| 
 | ||||
| 
 | ||||
class DistRewardSmallBilliard(GoalRewardSmallBilliard, DistRewardMixIn):
    """``GoalRewardSmallBilliard`` combined with ``DistRewardMixIn``.

    Adds no members of its own; all behavior comes from the two bases.
    """
| 
 | ||||
| 
 | ||||
class NoRewardSmallBilliard(GoalRewardSmallBilliard):
    """Small-billiard variant whose ``reward`` is always zero."""

    def reward(self, _obs: np.ndarray) -> float:
        """Return ``0.0`` regardless of the observation passed in."""
        return 0.0
| 
 | ||||
| 
 | ||||
| class TaskRegistry: | ||||
|     REGISTRY: Dict[str, List[Type[MazeTask]]] = { | ||||
|         "SimpleRoom": [DistRewardSimpleRoom, GoalRewardSimpleRoom], | ||||
| @ -697,6 +727,11 @@ class TaskRegistry: | ||||
|             BanditBilliard,  # v3 | ||||
|             NoRewardBilliard,  # v4 | ||||
|         ], | ||||
|         "SmallBilliard": [ | ||||
|             DistRewardSmallBilliard, | ||||
|             GoalRewardSmallBilliard, | ||||
|             NoRewardSmallBilliard, | ||||
|         ], | ||||
|     } | ||||
| 
 | ||||
|     @staticmethod | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user