Set a robot's Scaling value to None if the environment is not supported
This commit is contained in:
parent
3384934aff
commit
bf4e5b1e97
@ -16,69 +16,66 @@ from mujoco_maze.swimmer import SwimmerEnv
|
|||||||
|
|
||||||
for maze_id in TaskRegistry.keys():
|
for maze_id in TaskRegistry.keys():
|
||||||
for i, task_cls in enumerate(TaskRegistry.tasks(maze_id)):
|
for i, task_cls in enumerate(TaskRegistry.tasks(maze_id)):
|
||||||
# Point
|
point_scale = task_cls.MAZE_SIZE_SCALING.point
|
||||||
gym.envs.register(
|
if point_scale is not None:
|
||||||
id=f"Point{maze_id}-v{i}",
|
# Point
|
||||||
entry_point="mujoco_maze.maze_env:MazeEnv",
|
gym.envs.register(
|
||||||
kwargs=dict(
|
id=f"Point{maze_id}-v{i}",
|
||||||
model_cls=PointEnv,
|
entry_point="mujoco_maze.maze_env:MazeEnv",
|
||||||
maze_task=task_cls,
|
kwargs=dict(
|
||||||
maze_size_scaling=task_cls.MAZE_SIZE_SCALING.point,
|
model_cls=PointEnv,
|
||||||
inner_reward_scaling=task_cls.INNER_REWARD_SCALING,
|
maze_task=task_cls,
|
||||||
),
|
maze_size_scaling=point_scale,
|
||||||
max_episode_steps=1000,
|
inner_reward_scaling=task_cls.INNER_REWARD_SCALING,
|
||||||
reward_threshold=task_cls.REWARD_THRESHOLD,
|
),
|
||||||
)
|
max_episode_steps=1000,
|
||||||
if "Billiard" in maze_id:
|
reward_threshold=task_cls.REWARD_THRESHOLD,
|
||||||
continue
|
)
|
||||||
# Ant
|
|
||||||
gym.envs.register(
|
|
||||||
id=f"Ant{maze_id}-v{i}",
|
|
||||||
entry_point="mujoco_maze.maze_env:MazeEnv",
|
|
||||||
kwargs=dict(
|
|
||||||
model_cls=AntEnv,
|
|
||||||
maze_task=task_cls,
|
|
||||||
maze_size_scaling=task_cls.MAZE_SIZE_SCALING.ant,
|
|
||||||
inner_reward_scaling=task_cls.INNER_REWARD_SCALING,
|
|
||||||
),
|
|
||||||
max_episode_steps=1000,
|
|
||||||
reward_threshold=task_cls.REWARD_THRESHOLD,
|
|
||||||
)
|
|
||||||
skip_swimmer = False
|
|
||||||
for inhibited in ["Fall", "Push", "Block"]:
|
|
||||||
if inhibited in maze_id:
|
|
||||||
skip_swimmer = True
|
|
||||||
|
|
||||||
if skip_swimmer:
|
ant_scale = task_cls.MAZE_SIZE_SCALING.ant
|
||||||
continue
|
if ant_scale is not None:
|
||||||
|
# Ant
|
||||||
|
gym.envs.register(
|
||||||
|
id=f"Ant{maze_id}-v{i}",
|
||||||
|
entry_point="mujoco_maze.maze_env:MazeEnv",
|
||||||
|
kwargs=dict(
|
||||||
|
model_cls=AntEnv,
|
||||||
|
maze_task=task_cls,
|
||||||
|
maze_size_scaling=ant_scale,
|
||||||
|
inner_reward_scaling=task_cls.INNER_REWARD_SCALING,
|
||||||
|
),
|
||||||
|
max_episode_steps=1000,
|
||||||
|
reward_threshold=task_cls.REWARD_THRESHOLD,
|
||||||
|
)
|
||||||
|
|
||||||
# Reacher
|
swimmer_scale = task_cls.MAZE_SIZE_SCALING.swimmer
|
||||||
gym.envs.register(
|
if swimmer_scale is not None:
|
||||||
id=f"Reacher{maze_id}-v{i}",
|
# Reacher
|
||||||
entry_point="mujoco_maze.maze_env:MazeEnv",
|
gym.envs.register(
|
||||||
kwargs=dict(
|
id=f"Reacher{maze_id}-v{i}",
|
||||||
model_cls=ReacherEnv,
|
entry_point="mujoco_maze.maze_env:MazeEnv",
|
||||||
maze_task=task_cls,
|
kwargs=dict(
|
||||||
maze_size_scaling=task_cls.MAZE_SIZE_SCALING.swimmer,
|
model_cls=ReacherEnv,
|
||||||
inner_reward_scaling=task_cls.INNER_REWARD_SCALING,
|
maze_task=task_cls,
|
||||||
),
|
maze_size_scaling=task_cls.MAZE_SIZE_SCALING.swimmer,
|
||||||
max_episode_steps=1000,
|
inner_reward_scaling=task_cls.INNER_REWARD_SCALING,
|
||||||
reward_threshold=task_cls.REWARD_THRESHOLD,
|
),
|
||||||
)
|
max_episode_steps=1000,
|
||||||
|
reward_threshold=task_cls.REWARD_THRESHOLD,
|
||||||
# Swimmer
|
)
|
||||||
gym.envs.register(
|
# Swimmer
|
||||||
id=f"Swimmer{maze_id}-v{i}",
|
gym.envs.register(
|
||||||
entry_point="mujoco_maze.maze_env:MazeEnv",
|
id=f"Swimmer{maze_id}-v{i}",
|
||||||
kwargs=dict(
|
entry_point="mujoco_maze.maze_env:MazeEnv",
|
||||||
model_cls=SwimmerEnv,
|
kwargs=dict(
|
||||||
maze_task=task_cls,
|
model_cls=SwimmerEnv,
|
||||||
maze_size_scaling=task_cls.MAZE_SIZE_SCALING.swimmer,
|
maze_task=task_cls,
|
||||||
inner_reward_scaling=task_cls.INNER_REWARD_SCALING,
|
maze_size_scaling=task_cls.MAZE_SIZE_SCALING.swimmer,
|
||||||
),
|
inner_reward_scaling=task_cls.INNER_REWARD_SCALING,
|
||||||
max_episode_steps=1000,
|
),
|
||||||
reward_threshold=task_cls.REWARD_THRESHOLD,
|
max_episode_steps=1000,
|
||||||
)
|
reward_threshold=task_cls.REWARD_THRESHOLD,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
__version__ = "0.1.0"
|
__version__ = "0.1.0"
|
||||||
|
@ -48,9 +48,9 @@ class MazeGoal:
|
|||||||
|
|
||||||
|
|
||||||
class Scaling(NamedTuple):
|
class Scaling(NamedTuple):
|
||||||
ant: float
|
ant: Optional[float]
|
||||||
point: float
|
point: Optional[float]
|
||||||
swimmer: float
|
swimmer: Optional[float]
|
||||||
|
|
||||||
|
|
||||||
class MazeTask(ABC):
|
class MazeTask(ABC):
|
||||||
@ -330,6 +330,7 @@ class SubGoalTRoom(GoalRewardTRoom):
|
|||||||
|
|
||||||
|
|
||||||
class GoalRewardBlockMaze(GoalRewardUMaze):
|
class GoalRewardBlockMaze(GoalRewardUMaze):
|
||||||
|
MAZE_SIZE_SCALING: Scaling = Scaling(8.0, 4.0, None)
|
||||||
OBSERVE_BLOCKS: bool = True
|
OBSERVE_BLOCKS: bool = True
|
||||||
|
|
||||||
def __init__(self, scale: float) -> None:
|
def __init__(self, scale: float) -> None:
|
||||||
@ -357,7 +358,7 @@ class DistRewardBlockMaze(GoalRewardBlockMaze, DistRewardMixIn):
|
|||||||
class GoalRewardBilliard(MazeTask):
|
class GoalRewardBilliard(MazeTask):
|
||||||
REWARD_THRESHOLD: float = 0.9
|
REWARD_THRESHOLD: float = 0.9
|
||||||
PENALTY: float = -0.0001
|
PENALTY: float = -0.0001
|
||||||
MAZE_SIZE_SCALING: Scaling = Scaling(4.0, 3.0, 3.0)
|
MAZE_SIZE_SCALING: Scaling = Scaling(None, 3.0, None)
|
||||||
OBSERVE_BALLS: bool = True
|
OBSERVE_BALLS: bool = True
|
||||||
GOAL_SIZE: float = 0.3
|
GOAL_SIZE: float = 0.3
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user