BlockCarry
This commit is contained in:
parent
67c54afccd
commit
06cf7c9b8b
@ -515,6 +515,62 @@ class DistRewardBlockMaze(GoalRewardBlockMaze, DistRewardMixIn):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class GoalRewardBlockCarry(MazeTask):
|
||||||
|
REWARD_THRESHOLD: float = 0.9
|
||||||
|
PENALTY: float = -0.0001
|
||||||
|
MAZE_SIZE_SCALING: Scaling = Scaling(ant=2.0, point=3.0, swimmer=None)
|
||||||
|
OBSERVE_BLOCKS: bool = True
|
||||||
|
GOAL_SIZE: float = 0.3
|
||||||
|
|
||||||
|
def __init__(self, scale: float, goal: Tuple[float, float] = (2.0, 0.0)) -> None:
|
||||||
|
super().__init__(scale)
|
||||||
|
self.goals.append(
|
||||||
|
MazeGoal(
|
||||||
|
np.array(goal) * scale,
|
||||||
|
threshold=self.GOAL_SIZE + 0.5,
|
||||||
|
custom_size=self.GOAL_SIZE,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def reward(self, obs: np.ndarray) -> float:
|
||||||
|
object_pos = obs[3:6]
|
||||||
|
for goal in self.goals:
|
||||||
|
if goal.neighbor(object_pos):
|
||||||
|
return goal.reward_scale
|
||||||
|
return self.PENALTY
|
||||||
|
|
||||||
|
def termination(self, obs: np.ndarray) -> bool:
|
||||||
|
object_pos = obs[3:6]
|
||||||
|
for goal in self.goals:
|
||||||
|
if goal.neighbor(object_pos):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_maze() -> List[List[MazeCell]]:
|
||||||
|
E, B = MazeCell.EMPTY, MazeCell.BLOCK
|
||||||
|
R, M = MazeCell.ROBOT, MazeCell.XY_BLOCK
|
||||||
|
return [
|
||||||
|
[B, B, B, B, B],
|
||||||
|
[B, E, E, E, B],
|
||||||
|
[B, E, E, E, B],
|
||||||
|
[B, R, M, E, B],
|
||||||
|
[B, E, E, E, B],
|
||||||
|
[B, E, E, E, B],
|
||||||
|
[B, B, B, B, B],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class DistRewardBlockCarry(GoalRewardBlockCarry):
|
||||||
|
def reward(self, obs: np.ndarray) -> float:
|
||||||
|
return -self.goals[0].euc_dist(obs[3:6]) / self.scale
|
||||||
|
|
||||||
|
|
||||||
|
class NoRewardBlockCarry(GoalRewardBlockCarry):
|
||||||
|
def reward(self, _obs: np.ndarray) -> float:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
|
||||||
class GoalRewardBilliard(MazeTask):
|
class GoalRewardBilliard(MazeTask):
|
||||||
REWARD_THRESHOLD: float = 0.9
|
REWARD_THRESHOLD: float = 0.9
|
||||||
PENALTY: float = -0.0001
|
PENALTY: float = -0.0001
|
||||||
@ -633,6 +689,7 @@ class TaskRegistry:
|
|||||||
"BlockMaze": [DistRewardBlockMaze, GoalRewardBlockMaze],
|
"BlockMaze": [DistRewardBlockMaze, GoalRewardBlockMaze],
|
||||||
"Corridor": [DistRewardCorridor, GoalRewardCorridor, NoRewardCorridor],
|
"Corridor": [DistRewardCorridor, GoalRewardCorridor, NoRewardCorridor],
|
||||||
"LongCorridor": [DistRewardLongCorridor, GoalRewardLongCorridor],
|
"LongCorridor": [DistRewardLongCorridor, GoalRewardLongCorridor],
|
||||||
|
"BlockCarry": [DistRewardBlockCarry, GoalRewardBlockCarry, NoRewardBlockCarry],
|
||||||
"Billiard": [
|
"Billiard": [
|
||||||
DistRewardBilliard, # v0
|
DistRewardBilliard, # v0
|
||||||
GoalRewardBilliard, # v1
|
GoalRewardBilliard, # v1
|
||||||
|
Loading…
Reference in New Issue
Block a user