diff --git a/mujoco_maze/__init__.py b/mujoco_maze/__init__.py index f92403f..aa9085d 100644 --- a/mujoco_maze/__init__.py +++ b/mujoco_maze/__init__.py @@ -1,5 +1,15 @@ +""" +Mujoco Maze +----------- + +A maze environment using mujoco that supports custom tasks and robots. +""" + + import gym +from mujoco_maze.ant import AntEnv +from mujoco_maze.point import PointEnv from mujoco_maze.maze_task import TaskRegistry @@ -7,8 +17,9 @@ for maze_id in TaskRegistry.keys(): for i, task_cls in enumerate(TaskRegistry.tasks(maze_id)): gym.envs.register( id=f"Ant{maze_id}-v{i}", - entry_point="mujoco_maze.ant_maze_env:AntMazeEnv", + entry_point="mujoco_maze.maze_env:MazeEnv", kwargs=dict( + model_cls=AntEnv, maze_task=task_cls, maze_size_scaling=task_cls.MAZE_SIZE_SCALING.ant, inner_reward_scaling=task_cls.INNER_REWARD_SCALING, @@ -21,8 +32,9 @@ for maze_id in TaskRegistry.keys(): for i, task_cls in enumerate(TaskRegistry.tasks(maze_id)): gym.envs.register( id=f"Point{maze_id}-v{i}", - entry_point="mujoco_maze.point_maze_env:PointMazeEnv", + entry_point="mujoco_maze.maze_env:MazeEnv", kwargs=dict( + model_cls=PointEnv, maze_task=task_cls, maze_size_scaling=task_cls.MAZE_SIZE_SCALING.point, inner_reward_scaling=task_cls.INNER_REWARD_SCALING, diff --git a/mujoco_maze/agent_model.py b/mujoco_maze/agent_model.py index a6b95e5..d063fd3 100644 --- a/mujoco_maze/agent_model.py +++ b/mujoco_maze/agent_model.py @@ -1,4 +1,4 @@ -"""Common API definition for Ant and Point. +"""Common APIs for defining mujoco robots. 
""" from abc import ABC, abstractmethod @@ -10,6 +10,7 @@ from gym.utils import EzPickle class AgentModel(ABC, MujocoEnv, EzPickle): FILE: str ORI_IND: int + MANUAL_COLLISION: bool def __init__(self, file_path: str, frame_skip: int) -> None: MujocoEnv.__init__(self, file_path, frame_skip) diff --git a/mujoco_maze/ant.py b/mujoco_maze/ant.py index 96afb9b..225dc89 100644 --- a/mujoco_maze/ant.py +++ b/mujoco_maze/ant.py @@ -1,19 +1,10 @@ -# Copyright 2018 The TensorFlow Authors All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== +""" +A four-legged robot as an explorer in the maze. +Based on `models`_ and `gym`_ (both ant and ant-v3). -"""Wrapper for creating the ant environment in gym_mujoco.""" +.. _models: https://github.com/tensorflow/models/tree/master/research/efficient-hrl +.. _gym: https://github.com/openai/gym +""" import math from typing import Callable, Optional, Tuple @@ -49,6 +40,7 @@ def q_mult(a, b): # multiply two quaternion class AntEnv(AgentModel): FILE: str = "ant.xml" ORI_IND: int = 3 + MANUAL_COLLISION: bool = False def __init__( self, diff --git a/mujoco_maze/ant_maze_env.py b/mujoco_maze/ant_maze_env.py deleted file mode 100644 index 065414d..0000000 --- a/mujoco_maze/ant_maze_env.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright 2018 The TensorFlow Authors All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from mujoco_maze.ant import AntEnv -from mujoco_maze.maze_env import MazeEnv - - -class AntMazeEnv(MazeEnv): - MODEL_CLASS = AntEnv diff --git a/mujoco_maze/maze_env.py b/mujoco_maze/maze_env.py index edc39bb..6091133 100644 --- a/mujoco_maze/maze_env.py +++ b/mujoco_maze/maze_env.py @@ -1,19 +1,10 @@ -# Copyright 2018 The TensorFlow Authors All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== +""" +Mujoco Maze environment. +Based on `models`_ and `rllab`_. -"""Adapted from rllab maze_env.py.""" +.. _models: https://github.com/tensorflow/models/tree/master/research/efficient-hrl +.. 
_rllab: https://github.com/rll/rllab +""" import itertools as it import os @@ -32,11 +23,9 @@ MODEL_DIR = os.path.dirname(os.path.abspath(__file__)) + "/assets" class MazeEnv(gym.Env): - MODEL_CLASS: Type[AgentModel] = AgentModel - MANUAL_COLLISION: bool = False - def __init__( self, + model_cls: Type[AgentModel], maze_task: Type[maze_task.MazeTask] = maze_task.SingleGoalSparseUMaze, n_bins: int = 0, sensor_range: float = 3.0, @@ -50,7 +39,7 @@ class MazeEnv(gym.Env): ) -> None: self._task = maze_task(maze_size_scaling) - xml_path = os.path.join(MODEL_DIR, self.MODEL_CLASS.FILE) + xml_path = os.path.join(MODEL_DIR, model_cls.FILE) tree = ET.parse(xml_path) worldbody = tree.find(".//worldbody") @@ -260,7 +249,7 @@ class MazeEnv(gym.Env): _, file_path = tempfile.mkstemp(text=True, suffix=".xml") tree.write(file_path) self.world_tree = tree - self.wrapped_env = self.MODEL_CLASS(*args, file_path=file_path, **kwargs) + self.wrapped_env = model_cls(*args, file_path=file_path, **kwargs) self.observation_space = self._get_obs_space() def get_ori(self) -> float: @@ -541,7 +530,7 @@ class MazeEnv(gym.Env): def step(self, action): self.t += 1 - if self.MANUAL_COLLISION: + if self.wrapped_env.MANUAL_COLLISION: old_pos = self.wrapped_env.get_xy() inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action) new_pos = self.wrapped_env.get_xy() diff --git a/mujoco_maze/maze_env_utils.py b/mujoco_maze/maze_env_utils.py index d3c45a7..3e432bd 100644 --- a/mujoco_maze/maze_env_utils.py +++ b/mujoco_maze/maze_env_utils.py @@ -1,19 +1,11 @@ -# Copyright 2018 The TensorFlow Authors All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== +""" +Utilities for creating mazes. +Based on `models`_ and `rllab`_. + +.. _models: https://github.com/tensorflow/models/tree/master/research/efficient-hrl +.. _rllab: https://github.com/rll/rllab +""" -"""Adapted from rllab maze_env_utils.py.""" import itertools as it import math from enum import Enum diff --git a/mujoco_maze/maze_task.py b/mujoco_maze/maze_task.py index b208a10..ec6158d 100644 --- a/mujoco_maze/maze_task.py +++ b/mujoco_maze/maze_task.py @@ -1,3 +1,6 @@ +"""Maze tasks that are defined by their map, termination condition, and goals. +""" + from abc import ABC, abstractmethod from typing import Dict, List, NamedTuple, Type diff --git a/mujoco_maze/point.py b/mujoco_maze/point.py index 86351f3..5d1a959 100644 --- a/mujoco_maze/point.py +++ b/mujoco_maze/point.py @@ -1,19 +1,10 @@ -# Copyright 2018 The TensorFlow Authors All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# ============================================================================== +""" +A ball-like robot as an explorer in the maze. +Based on `models`_ and `rllab`_. -"""Wrapper for creating the ant environment in gym_mujoco.""" +.. _models: https://github.com/tensorflow/models/tree/master/research/efficient-hrl +.. _rllab: https://github.com/rll/rllab +""" import math from typing import Optional, Tuple @@ -27,6 +18,7 @@ from mujoco_maze.agent_model import AgentModel class PointEnv(AgentModel): FILE: str = "point.xml" ORI_IND: int = 2 + MANUAL_COLLISION: bool = True VELOCITY_LIMITS: float = 10.0 diff --git a/mujoco_maze/point_maze_env.py b/mujoco_maze/point_maze_env.py deleted file mode 100644 index 6d92cf8..0000000 --- a/mujoco_maze/point_maze_env.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright 2018 The TensorFlow Authors All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from mujoco_maze.maze_env import MazeEnv -from mujoco_maze.point import PointEnv - - -class PointMazeEnv(MazeEnv): - MODEL_CLASS = PointEnv - MANUAL_COLLISION = True