Scaling=None if the environment is not supported

This commit is contained in:
kngwyu 2020-10-05 13:52:21 +09:00
parent 3384934aff
commit bf4e5b1e97
2 changed files with 63 additions and 65 deletions

View File

@ -16,69 +16,66 @@ from mujoco_maze.swimmer import SwimmerEnv
for maze_id in TaskRegistry.keys(): for maze_id in TaskRegistry.keys():
for i, task_cls in enumerate(TaskRegistry.tasks(maze_id)): for i, task_cls in enumerate(TaskRegistry.tasks(maze_id)):
# Point point_scale = task_cls.MAZE_SIZE_SCALING.point
gym.envs.register( if point_scale is not None:
id=f"Point{maze_id}-v{i}", # Point
entry_point="mujoco_maze.maze_env:MazeEnv", gym.envs.register(
kwargs=dict( id=f"Point{maze_id}-v{i}",
model_cls=PointEnv, entry_point="mujoco_maze.maze_env:MazeEnv",
maze_task=task_cls, kwargs=dict(
maze_size_scaling=task_cls.MAZE_SIZE_SCALING.point, model_cls=PointEnv,
inner_reward_scaling=task_cls.INNER_REWARD_SCALING, maze_task=task_cls,
), maze_size_scaling=point_scale,
max_episode_steps=1000, inner_reward_scaling=task_cls.INNER_REWARD_SCALING,
reward_threshold=task_cls.REWARD_THRESHOLD, ),
) max_episode_steps=1000,
if "Billiard" in maze_id: reward_threshold=task_cls.REWARD_THRESHOLD,
continue )
# Ant
gym.envs.register(
id=f"Ant{maze_id}-v{i}",
entry_point="mujoco_maze.maze_env:MazeEnv",
kwargs=dict(
model_cls=AntEnv,
maze_task=task_cls,
maze_size_scaling=task_cls.MAZE_SIZE_SCALING.ant,
inner_reward_scaling=task_cls.INNER_REWARD_SCALING,
),
max_episode_steps=1000,
reward_threshold=task_cls.REWARD_THRESHOLD,
)
skip_swimmer = False
for inhibited in ["Fall", "Push", "Block"]:
if inhibited in maze_id:
skip_swimmer = True
if skip_swimmer: ant_scale = task_cls.MAZE_SIZE_SCALING.ant
continue if ant_scale is not None:
# Ant
gym.envs.register(
id=f"Ant{maze_id}-v{i}",
entry_point="mujoco_maze.maze_env:MazeEnv",
kwargs=dict(
model_cls=AntEnv,
maze_task=task_cls,
maze_size_scaling=ant_scale,
inner_reward_scaling=task_cls.INNER_REWARD_SCALING,
),
max_episode_steps=1000,
reward_threshold=task_cls.REWARD_THRESHOLD,
)
# Reacher swimmer_scale = task_cls.MAZE_SIZE_SCALING.swimmer
gym.envs.register( if swimmer_scale is not None:
id=f"Reacher{maze_id}-v{i}", # Reacher
entry_point="mujoco_maze.maze_env:MazeEnv", gym.envs.register(
kwargs=dict( id=f"Reacher{maze_id}-v{i}",
model_cls=ReacherEnv, entry_point="mujoco_maze.maze_env:MazeEnv",
maze_task=task_cls, kwargs=dict(
maze_size_scaling=task_cls.MAZE_SIZE_SCALING.swimmer, model_cls=ReacherEnv,
inner_reward_scaling=task_cls.INNER_REWARD_SCALING, maze_task=task_cls,
), maze_size_scaling=task_cls.MAZE_SIZE_SCALING.swimmer,
max_episode_steps=1000, inner_reward_scaling=task_cls.INNER_REWARD_SCALING,
reward_threshold=task_cls.REWARD_THRESHOLD, ),
) max_episode_steps=1000,
reward_threshold=task_cls.REWARD_THRESHOLD,
# Swimmer )
gym.envs.register( # Swimmer
id=f"Swimmer{maze_id}-v{i}", gym.envs.register(
entry_point="mujoco_maze.maze_env:MazeEnv", id=f"Swimmer{maze_id}-v{i}",
kwargs=dict( entry_point="mujoco_maze.maze_env:MazeEnv",
model_cls=SwimmerEnv, kwargs=dict(
maze_task=task_cls, model_cls=SwimmerEnv,
maze_size_scaling=task_cls.MAZE_SIZE_SCALING.swimmer, maze_task=task_cls,
inner_reward_scaling=task_cls.INNER_REWARD_SCALING, maze_size_scaling=task_cls.MAZE_SIZE_SCALING.swimmer,
), inner_reward_scaling=task_cls.INNER_REWARD_SCALING,
max_episode_steps=1000, ),
reward_threshold=task_cls.REWARD_THRESHOLD, max_episode_steps=1000,
) reward_threshold=task_cls.REWARD_THRESHOLD,
)
__version__ = "0.1.0" __version__ = "0.1.0"

View File

@ -48,9 +48,9 @@ class MazeGoal:
class Scaling(NamedTuple): class Scaling(NamedTuple):
ant: float ant: Optional[float]
point: float point: Optional[float]
swimmer: float swimmer: Optional[float]
class MazeTask(ABC): class MazeTask(ABC):
@ -330,6 +330,7 @@ class SubGoalTRoom(GoalRewardTRoom):
class GoalRewardBlockMaze(GoalRewardUMaze): class GoalRewardBlockMaze(GoalRewardUMaze):
MAZE_SIZE_SCALING: Scaling = Scaling(8.0, 4.0, None)
OBSERVE_BLOCKS: bool = True OBSERVE_BLOCKS: bool = True
def __init__(self, scale: float) -> None: def __init__(self, scale: float) -> None:
@ -357,7 +358,7 @@ class DistRewardBlockMaze(GoalRewardBlockMaze, DistRewardMixIn):
class GoalRewardBilliard(MazeTask): class GoalRewardBilliard(MazeTask):
REWARD_THRESHOLD: float = 0.9 REWARD_THRESHOLD: float = 0.9
PENALTY: float = -0.0001 PENALTY: float = -0.0001
MAZE_SIZE_SCALING: Scaling = Scaling(4.0, 3.0, 3.0) MAZE_SIZE_SCALING: Scaling = Scaling(None, 3.0, None)
OBSERVE_BALLS: bool = True OBSERVE_BALLS: bool = True
GOAL_SIZE: float = 0.3 GOAL_SIZE: float = 0.3