Introduce BanditBilliard (which keeps the previous default primary goal and single subgoal) and change SubGoalBilliard to a more conventional task with multiple subgoals

This commit is contained in:
kngwyu 2020-10-01 00:20:48 +09:00
parent 0ec69ab4e2
commit 28711cee19

View File

@ -396,19 +396,30 @@ class SubGoalBilliard(GoalRewardBilliard):
def __init__( def __init__(
self, self,
scale: float, scale: float,
primary_goal: Tuple[float, float] = (4.0, -2.0), primary_goal: Tuple[float, float] = (2.0, -3.0),
subgoal: Tuple[float, float] = (4.0, 2.0), subgoals: List[Tuple[float, float]] = [(-2.0, -3.0), (-2.0, 1.0), (2.0, 1.0)],
) -> None: ) -> None:
super().__init__(scale, primary_goal) super().__init__(scale, primary_goal)
self.goals.append( for subgoal in subgoals:
MazeGoal( self.goals.append(
np.array(subgoal) * scale, MazeGoal(
reward_scale=0.5, np.array(subgoal) * scale,
rgb=GREEN, reward_scale=0.5,
threshold=self._threshold(), rgb=GREEN,
custom_size=self.GOAL_SIZE, threshold=self._threshold(),
custom_size=self.GOAL_SIZE,
)
) )
)
class BanditBilliard(SubGoalBilliard):
def __init__(
    self,
    scale: float,
    primary_goal: Tuple[float, float] = (4.0, -2.0),
    subgoals: List[Tuple[float, float]] = [(4.0, 2.0)],
) -> None:
    """Set up the task by delegating to the parent initializer.

    Args:
        scale: Passed through unchanged to the parent class.
        primary_goal: Primary goal coordinates; defaults to (4.0, -2.0).
        subgoals: Subgoal coordinates; defaults to the single point
            (4.0, 2.0).
    """
    # NOTE: the default list object is shared between calls; it is only
    # forwarded here and must not be mutated in place.
    super().__init__(scale, primary_goal=primary_goal, subgoals=subgoals)
@staticmethod @staticmethod
def create_maze() -> List[List[MazeCell]]: def create_maze() -> List[List[MazeCell]]:
@ -435,7 +446,12 @@ class TaskRegistry:
"4Rooms": [DistReward4Rooms, GoalReward4Rooms, SubGoal4Rooms], "4Rooms": [DistReward4Rooms, GoalReward4Rooms, SubGoal4Rooms],
"TRoom": [DistRewardTRoom, GoalRewardTRoom], "TRoom": [DistRewardTRoom, GoalRewardTRoom],
"BlockMaze": [DistRewardBlockMaze, GoalRewardBlockMaze], "BlockMaze": [DistRewardBlockMaze, GoalRewardBlockMaze],
"Billiard": [DistRewardBilliard, GoalRewardBilliard, SubGoalBilliard], "Billiard": [
DistRewardBilliard,
GoalRewardBilliard,
SubGoalBilliard,
BanditBilliard,
],
} }
@staticmethod @staticmethod