BlockCarry
This commit is contained in:
parent
67c54afccd
commit
06cf7c9b8b
@ -515,6 +515,62 @@ class DistRewardBlockMaze(GoalRewardBlockMaze, DistRewardMixIn):
|
||||
pass
|
||||
|
||||
|
||||
class GoalRewardBlockCarry(MazeTask):
|
||||
REWARD_THRESHOLD: float = 0.9
|
||||
PENALTY: float = -0.0001
|
||||
MAZE_SIZE_SCALING: Scaling = Scaling(ant=2.0, point=3.0, swimmer=None)
|
||||
OBSERVE_BLOCKS: bool = True
|
||||
GOAL_SIZE: float = 0.3
|
||||
|
||||
def __init__(self, scale: float, goal: Tuple[float, float] = (2.0, 0.0)) -> None:
|
||||
super().__init__(scale)
|
||||
self.goals.append(
|
||||
MazeGoal(
|
||||
np.array(goal) * scale,
|
||||
threshold=self.GOAL_SIZE + 0.5,
|
||||
custom_size=self.GOAL_SIZE,
|
||||
)
|
||||
)
|
||||
|
||||
def reward(self, obs: np.ndarray) -> float:
|
||||
object_pos = obs[3:6]
|
||||
for goal in self.goals:
|
||||
if goal.neighbor(object_pos):
|
||||
return goal.reward_scale
|
||||
return self.PENALTY
|
||||
|
||||
def termination(self, obs: np.ndarray) -> bool:
|
||||
object_pos = obs[3:6]
|
||||
for goal in self.goals:
|
||||
if goal.neighbor(object_pos):
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def create_maze() -> List[List[MazeCell]]:
|
||||
E, B = MazeCell.EMPTY, MazeCell.BLOCK
|
||||
R, M = MazeCell.ROBOT, MazeCell.XY_BLOCK
|
||||
return [
|
||||
[B, B, B, B, B],
|
||||
[B, E, E, E, B],
|
||||
[B, E, E, E, B],
|
||||
[B, R, M, E, B],
|
||||
[B, E, E, E, B],
|
||||
[B, E, E, E, B],
|
||||
[B, B, B, B, B],
|
||||
]
|
||||
|
||||
|
||||
class DistRewardBlockCarry(GoalRewardBlockCarry):
|
||||
def reward(self, obs: np.ndarray) -> float:
|
||||
return -self.goals[0].euc_dist(obs[3:6]) / self.scale
|
||||
|
||||
|
||||
class NoRewardBlockCarry(GoalRewardBlockCarry):
|
||||
def reward(self, _obs: np.ndarray) -> float:
|
||||
return 0.0
|
||||
|
||||
|
||||
class GoalRewardBilliard(MazeTask):
|
||||
REWARD_THRESHOLD: float = 0.9
|
||||
PENALTY: float = -0.0001
|
||||
@ -633,6 +689,7 @@ class TaskRegistry:
|
||||
"BlockMaze": [DistRewardBlockMaze, GoalRewardBlockMaze],
|
||||
"Corridor": [DistRewardCorridor, GoalRewardCorridor, NoRewardCorridor],
|
||||
"LongCorridor": [DistRewardLongCorridor, GoalRewardLongCorridor],
|
||||
"BlockCarry": [DistRewardBlockCarry, GoalRewardBlockCarry, NoRewardBlockCarry],
|
||||
"Billiard": [
|
||||
DistRewardBilliard, # v0
|
||||
GoalRewardBilliard, # v1
|
||||
|
Loading…
Reference in New Issue
Block a user