Introduce BanditBilliard and change SubGoalBilliard to a more normal one

This commit is contained in:
kngwyu 2020-10-01 00:20:48 +09:00
parent 0ec69ab4e2
commit 28711cee19

View File

@ -396,19 +396,30 @@ class SubGoalBilliard(GoalRewardBilliard):
def __init__(
self,
scale: float,
primary_goal: Tuple[float, float] = (4.0, -2.0),
subgoal: Tuple[float, float] = (4.0, 2.0),
primary_goal: Tuple[float, float] = (2.0, -3.0),
subgoals: List[Tuple[float, float]] = [(-2.0, -3.0), (-2.0, 1.0), (2.0, 1.0)],
) -> None:
super().__init__(scale, primary_goal)
self.goals.append(
MazeGoal(
np.array(subgoal) * scale,
reward_scale=0.5,
rgb=GREEN,
threshold=self._threshold(),
custom_size=self.GOAL_SIZE,
for subgoal in subgoals:
self.goals.append(
MazeGoal(
np.array(subgoal) * scale,
reward_scale=0.5,
rgb=GREEN,
threshold=self._threshold(),
custom_size=self.GOAL_SIZE,
)
)
)
class BanditBilliard(SubGoalBilliard):
def __init__(
self,
scale: float,
primary_goal: Tuple[float, float] = (4.0, -2.0),
subgoals: List[Tuple[float, float]] = [(4.0, 2.0)],
) -> None:
super().__init__(scale, primary_goal, subgoals)
@staticmethod
def create_maze() -> List[List[MazeCell]]:
@ -435,7 +446,12 @@ class TaskRegistry:
"4Rooms": [DistReward4Rooms, GoalReward4Rooms, SubGoal4Rooms],
"TRoom": [DistRewardTRoom, GoalRewardTRoom],
"BlockMaze": [DistRewardBlockMaze, GoalRewardBlockMaze],
"Billiard": [DistRewardBilliard, GoalRewardBilliard, SubGoalBilliard],
"Billiard": [
DistRewardBilliard,
GoalRewardBilliard,
SubGoalBilliard,
BanditBilliard,
],
}
@staticmethod