diff --git a/mujoco_maze/__init__.py b/mujoco_maze/__init__.py index 1ff8ecd..525fc83 100644 --- a/mujoco_maze/__init__.py +++ b/mujoco_maze/__init__.py @@ -3,20 +3,13 @@ import gym from mujoco_maze.maze_task import TaskRegistry -def _get_kwargs(maze_id: str) -> tuple: - return { - "maze_id": maze_id, - "observe_blocks": maze_id in ["Block", "BlockMaze"], - "put_spin_near_agent": maze_id in ["Block", "BlockMaze"], - } - - for maze_id in TaskRegistry.keys(): for i, task_cls in enumerate(TaskRegistry.tasks(maze_id)): + scaling = task_cls.SCALING.ant gym.envs.register( id=f"Ant{maze_id}-v{i}", entry_point="mujoco_maze.ant_maze_env:AntMazeEnv", - kwargs=dict(maze_task=task_cls, maze_size_scaling=8.0), + kwargs=dict(maze_task=task_cls, maze_size_scaling=scaling), max_episode_steps=1000, reward_threshold=task_cls.REWARD_THRESHOLD, ) @@ -26,7 +19,7 @@ for maze_id in TaskRegistry.keys(): gym.envs.register( id=f"Point{maze_id}-v{i}", entry_point="mujoco_maze.point_maze_env:PointMazeEnv", - kwargs=dict(maze_task=task_cls), + kwargs=dict(maze_task=task_cls, maze_size_scaling=scaling), max_episode_steps=1000, reward_threshold=task_cls.REWARD_THRESHOLD, ) diff --git a/mujoco_maze/maze_env.py b/mujoco_maze/maze_env.py index 1d2f5a1..007773f 100644 --- a/mujoco_maze/maze_env.py +++ b/mujoco_maze/maze_env.py @@ -41,8 +41,6 @@ class MazeEnv(gym.Env): n_bins: int = 0, sensor_range: float = 3.0, sensor_span: float = 2 * np.pi, - observe_blocks: float = False, - put_spin_near_agent: float = False, top_down_view: float = False, maze_height: float = 0.5, maze_size_scaling: float = 4.0, @@ -61,8 +59,8 @@ class MazeEnv(gym.Env): self._n_bins = n_bins self._sensor_range = sensor_range * size_scaling self._sensor_span = sensor_span - self._observe_blocks = observe_blocks - self._put_spin_near_agent = put_spin_near_agent + self._observe_blocks = self._task.OBSERVE_BLOCKS + self._put_spin_near_agent = self._task.PUT_SPIN_NEAR_AGENT self._top_down_view = top_down_view self._collision_coef = 0.1 diff --git a/mujoco_maze/maze_task.py b/mujoco_maze/maze_task.py index f2a04e3..a053ecd 100644 --- a/mujoco_maze/maze_task.py +++ b/mujoco_maze/maze_task.py @@ -40,8 +40,16 @@ class MazeGoal: return np.sum(np.square(obs[: self.dim] - self.pos)) ** 0.5 +class Scaling(NamedTuple): + ant: float + point: float + + class MazeTask(ABC): REWARD_THRESHOLD: float + SCALING: Scaling = Scaling(8.0, 4.0) + OBSERVE_BLOCKS: bool = False + PUT_SPIN_NEAR_AGENT: bool = False def __init__(self, scale: float) -> None: self.scale = scale