Scaling=None if the environment is not supported
This commit is contained in:
parent
3384934aff
commit
bf4e5b1e97
@ -16,69 +16,66 @@ from mujoco_maze.swimmer import SwimmerEnv
|
||||
|
||||
for maze_id in TaskRegistry.keys():
|
||||
for i, task_cls in enumerate(TaskRegistry.tasks(maze_id)):
|
||||
# Point
|
||||
gym.envs.register(
|
||||
id=f"Point{maze_id}-v{i}",
|
||||
entry_point="mujoco_maze.maze_env:MazeEnv",
|
||||
kwargs=dict(
|
||||
model_cls=PointEnv,
|
||||
maze_task=task_cls,
|
||||
maze_size_scaling=task_cls.MAZE_SIZE_SCALING.point,
|
||||
inner_reward_scaling=task_cls.INNER_REWARD_SCALING,
|
||||
),
|
||||
max_episode_steps=1000,
|
||||
reward_threshold=task_cls.REWARD_THRESHOLD,
|
||||
)
|
||||
if "Billiard" in maze_id:
|
||||
continue
|
||||
# Ant
|
||||
gym.envs.register(
|
||||
id=f"Ant{maze_id}-v{i}",
|
||||
entry_point="mujoco_maze.maze_env:MazeEnv",
|
||||
kwargs=dict(
|
||||
model_cls=AntEnv,
|
||||
maze_task=task_cls,
|
||||
maze_size_scaling=task_cls.MAZE_SIZE_SCALING.ant,
|
||||
inner_reward_scaling=task_cls.INNER_REWARD_SCALING,
|
||||
),
|
||||
max_episode_steps=1000,
|
||||
reward_threshold=task_cls.REWARD_THRESHOLD,
|
||||
)
|
||||
skip_swimmer = False
|
||||
for inhibited in ["Fall", "Push", "Block"]:
|
||||
if inhibited in maze_id:
|
||||
skip_swimmer = True
|
||||
point_scale = task_cls.MAZE_SIZE_SCALING.point
|
||||
if point_scale is not None:
|
||||
# Point
|
||||
gym.envs.register(
|
||||
id=f"Point{maze_id}-v{i}",
|
||||
entry_point="mujoco_maze.maze_env:MazeEnv",
|
||||
kwargs=dict(
|
||||
model_cls=PointEnv,
|
||||
maze_task=task_cls,
|
||||
maze_size_scaling=point_scale,
|
||||
inner_reward_scaling=task_cls.INNER_REWARD_SCALING,
|
||||
),
|
||||
max_episode_steps=1000,
|
||||
reward_threshold=task_cls.REWARD_THRESHOLD,
|
||||
)
|
||||
|
||||
if skip_swimmer:
|
||||
continue
|
||||
ant_scale = task_cls.MAZE_SIZE_SCALING.ant
|
||||
if ant_scale is not None:
|
||||
# Ant
|
||||
gym.envs.register(
|
||||
id=f"Ant{maze_id}-v{i}",
|
||||
entry_point="mujoco_maze.maze_env:MazeEnv",
|
||||
kwargs=dict(
|
||||
model_cls=AntEnv,
|
||||
maze_task=task_cls,
|
||||
maze_size_scaling=ant_scale,
|
||||
inner_reward_scaling=task_cls.INNER_REWARD_SCALING,
|
||||
),
|
||||
max_episode_steps=1000,
|
||||
reward_threshold=task_cls.REWARD_THRESHOLD,
|
||||
)
|
||||
|
||||
# Reacher
|
||||
gym.envs.register(
|
||||
id=f"Reacher{maze_id}-v{i}",
|
||||
entry_point="mujoco_maze.maze_env:MazeEnv",
|
||||
kwargs=dict(
|
||||
model_cls=ReacherEnv,
|
||||
maze_task=task_cls,
|
||||
maze_size_scaling=task_cls.MAZE_SIZE_SCALING.swimmer,
|
||||
inner_reward_scaling=task_cls.INNER_REWARD_SCALING,
|
||||
),
|
||||
max_episode_steps=1000,
|
||||
reward_threshold=task_cls.REWARD_THRESHOLD,
|
||||
)
|
||||
|
||||
# Swimmer
|
||||
gym.envs.register(
|
||||
id=f"Swimmer{maze_id}-v{i}",
|
||||
entry_point="mujoco_maze.maze_env:MazeEnv",
|
||||
kwargs=dict(
|
||||
model_cls=SwimmerEnv,
|
||||
maze_task=task_cls,
|
||||
maze_size_scaling=task_cls.MAZE_SIZE_SCALING.swimmer,
|
||||
inner_reward_scaling=task_cls.INNER_REWARD_SCALING,
|
||||
),
|
||||
max_episode_steps=1000,
|
||||
reward_threshold=task_cls.REWARD_THRESHOLD,
|
||||
)
|
||||
swimmer_scale = task_cls.MAZE_SIZE_SCALING.swimmer
|
||||
if swimmer_scale is not None:
|
||||
# Reacher
|
||||
gym.envs.register(
|
||||
id=f"Reacher{maze_id}-v{i}",
|
||||
entry_point="mujoco_maze.maze_env:MazeEnv",
|
||||
kwargs=dict(
|
||||
model_cls=ReacherEnv,
|
||||
maze_task=task_cls,
|
||||
maze_size_scaling=task_cls.MAZE_SIZE_SCALING.swimmer,
|
||||
inner_reward_scaling=task_cls.INNER_REWARD_SCALING,
|
||||
),
|
||||
max_episode_steps=1000,
|
||||
reward_threshold=task_cls.REWARD_THRESHOLD,
|
||||
)
|
||||
# Swimmer
|
||||
gym.envs.register(
|
||||
id=f"Swimmer{maze_id}-v{i}",
|
||||
entry_point="mujoco_maze.maze_env:MazeEnv",
|
||||
kwargs=dict(
|
||||
model_cls=SwimmerEnv,
|
||||
maze_task=task_cls,
|
||||
maze_size_scaling=task_cls.MAZE_SIZE_SCALING.swimmer,
|
||||
inner_reward_scaling=task_cls.INNER_REWARD_SCALING,
|
||||
),
|
||||
max_episode_steps=1000,
|
||||
reward_threshold=task_cls.REWARD_THRESHOLD,
|
||||
)
|
||||
|
||||
|
||||
__version__ = "0.1.0"
|
||||
|
@ -48,9 +48,9 @@ class MazeGoal:
|
||||
|
||||
|
||||
class Scaling(NamedTuple):
|
||||
ant: float
|
||||
point: float
|
||||
swimmer: float
|
||||
ant: Optional[float]
|
||||
point: Optional[float]
|
||||
swimmer: Optional[float]
|
||||
|
||||
|
||||
class MazeTask(ABC):
|
||||
@ -330,6 +330,7 @@ class SubGoalTRoom(GoalRewardTRoom):
|
||||
|
||||
|
||||
class GoalRewardBlockMaze(GoalRewardUMaze):
|
||||
MAZE_SIZE_SCALING: Scaling = Scaling(8.0, 4.0, None)
|
||||
OBSERVE_BLOCKS: bool = True
|
||||
|
||||
def __init__(self, scale: float) -> None:
|
||||
@ -357,7 +358,7 @@ class DistRewardBlockMaze(GoalRewardBlockMaze, DistRewardMixIn):
|
||||
class GoalRewardBilliard(MazeTask):
|
||||
REWARD_THRESHOLD: float = 0.9
|
||||
PENALTY: float = -0.0001
|
||||
MAZE_SIZE_SCALING: Scaling = Scaling(4.0, 3.0, 3.0)
|
||||
MAZE_SIZE_SCALING: Scaling = Scaling(None, 3.0, None)
|
||||
OBSERVE_BALLS: bool = True
|
||||
GOAL_SIZE: float = 0.3
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user