Remove tensorflow headers and Add doc comments

This commit is contained in:
kngwyu 2020-07-01 14:12:06 +09:00
parent 91249105b8
commit 14351830bb
9 changed files with 50 additions and 112 deletions

View File

@ -1,5 +1,15 @@
"""
Mujoco Maze
----------
A maze environment using mujoco that supports custom tasks and robots.
"""
import gym import gym
from mujoco_maze.ant import AntEnv
from mujoco_maze.point import PointEnv
from mujoco_maze.maze_task import TaskRegistry from mujoco_maze.maze_task import TaskRegistry
@ -7,8 +17,9 @@ for maze_id in TaskRegistry.keys():
for i, task_cls in enumerate(TaskRegistry.tasks(maze_id)): for i, task_cls in enumerate(TaskRegistry.tasks(maze_id)):
gym.envs.register( gym.envs.register(
id=f"Ant{maze_id}-v{i}", id=f"Ant{maze_id}-v{i}",
entry_point="mujoco_maze.ant_maze_env:AntMazeEnv", entry_point="mujoco_maze.maze_env:MazeEnv",
kwargs=dict( kwargs=dict(
model_cls=AntEnv,
maze_task=task_cls, maze_task=task_cls,
maze_size_scaling=task_cls.MAZE_SIZE_SCALING.ant, maze_size_scaling=task_cls.MAZE_SIZE_SCALING.ant,
inner_reward_scaling=task_cls.INNER_REWARD_SCALING, inner_reward_scaling=task_cls.INNER_REWARD_SCALING,
@ -21,8 +32,9 @@ for maze_id in TaskRegistry.keys():
for i, task_cls in enumerate(TaskRegistry.tasks(maze_id)): for i, task_cls in enumerate(TaskRegistry.tasks(maze_id)):
gym.envs.register( gym.envs.register(
id=f"Point{maze_id}-v{i}", id=f"Point{maze_id}-v{i}",
entry_point="mujoco_maze.point_maze_env:PointMazeEnv", entry_point="mujoco_maze.maze_env:MazeEnv",
kwargs=dict( kwargs=dict(
model_cls=PointEnv,
maze_task=task_cls, maze_task=task_cls,
maze_size_scaling=task_cls.MAZE_SIZE_SCALING.point, maze_size_scaling=task_cls.MAZE_SIZE_SCALING.point,
inner_reward_scaling=task_cls.INNER_REWARD_SCALING, inner_reward_scaling=task_cls.INNER_REWARD_SCALING,

View File

@ -1,4 +1,4 @@
"""Common API definition for Ant and Point. """Common APIs for defining mujoco robot.
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@ -10,6 +10,7 @@ from gym.utils import EzPickle
class AgentModel(ABC, MujocoEnv, EzPickle): class AgentModel(ABC, MujocoEnv, EzPickle):
FILE: str FILE: str
ORI_IND: int ORI_IND: int
MANUAL_COLLISION: bool
def __init__(self, file_path: str, frame_skip: int) -> None: def __init__(self, file_path: str, frame_skip: int) -> None:
MujocoEnv.__init__(self, file_path, frame_skip) MujocoEnv.__init__(self, file_path, frame_skip)

View File

@ -1,19 +1,10 @@
# Copyright 2018 The TensorFlow Authors All Rights Reserved. """
# A four-legged robot as an explorer in the maze.
# Licensed under the Apache License, Version 2.0 (the "License"); Based on `models`_ and `gym`_ (both ant and ant-v3).
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrapper for creating the ant environment in gym_mujoco.""" .. _models: https://github.com/tensorflow/models/tree/master/research/efficient-hrl
.. _gym: https://github.com/openai/gym
"""
import math import math
from typing import Callable, Optional, Tuple from typing import Callable, Optional, Tuple
@ -49,6 +40,7 @@ def q_mult(a, b): # multiply two quaternion
class AntEnv(AgentModel): class AntEnv(AgentModel):
FILE: str = "ant.xml" FILE: str = "ant.xml"
ORI_IND: int = 3 ORI_IND: int = 3
MANUAL_COLLISION: bool = False
def __init__( def __init__(
self, self,

View File

@ -1,21 +0,0 @@
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from mujoco_maze.ant import AntEnv
from mujoco_maze.maze_env import MazeEnv
class AntMazeEnv(MazeEnv):
MODEL_CLASS = AntEnv

View File

@ -1,19 +1,10 @@
# Copyright 2018 The TensorFlow Authors All Rights Reserved. """
# Mujoco Maze environment.
# Licensed under the Apache License, Version 2.0 (the "License"); Based on `models`_ and `rllab`_.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Adapted from rllab maze_env.py.""" .. _models: https://github.com/tensorflow/models/tree/master/research/efficient-hrl
.. _rllab: https://github.com/rll/rllab
"""
import itertools as it import itertools as it
import os import os
@ -32,11 +23,9 @@ MODEL_DIR = os.path.dirname(os.path.abspath(__file__)) + "/assets"
class MazeEnv(gym.Env): class MazeEnv(gym.Env):
MODEL_CLASS: Type[AgentModel] = AgentModel
MANUAL_COLLISION: bool = False
def __init__( def __init__(
self, self,
model_cls: Type[AgentModel],
maze_task: Type[maze_task.MazeTask] = maze_task.SingleGoalSparseUMaze, maze_task: Type[maze_task.MazeTask] = maze_task.SingleGoalSparseUMaze,
n_bins: int = 0, n_bins: int = 0,
sensor_range: float = 3.0, sensor_range: float = 3.0,
@ -50,7 +39,7 @@ class MazeEnv(gym.Env):
) -> None: ) -> None:
self._task = maze_task(maze_size_scaling) self._task = maze_task(maze_size_scaling)
xml_path = os.path.join(MODEL_DIR, self.MODEL_CLASS.FILE) xml_path = os.path.join(MODEL_DIR, model_cls.FILE)
tree = ET.parse(xml_path) tree = ET.parse(xml_path)
worldbody = tree.find(".//worldbody") worldbody = tree.find(".//worldbody")
@ -260,7 +249,7 @@ class MazeEnv(gym.Env):
_, file_path = tempfile.mkstemp(text=True, suffix=".xml") _, file_path = tempfile.mkstemp(text=True, suffix=".xml")
tree.write(file_path) tree.write(file_path)
self.world_tree = tree self.world_tree = tree
self.wrapped_env = self.MODEL_CLASS(*args, file_path=file_path, **kwargs) self.wrapped_env = model_cls(*args, file_path=file_path, **kwargs)
self.observation_space = self._get_obs_space() self.observation_space = self._get_obs_space()
def get_ori(self) -> float: def get_ori(self) -> float:
@ -541,7 +530,7 @@ class MazeEnv(gym.Env):
def step(self, action): def step(self, action):
self.t += 1 self.t += 1
if self.MANUAL_COLLISION: if self.wrapped_env.MANUAL_COLLISION:
old_pos = self.wrapped_env.get_xy() old_pos = self.wrapped_env.get_xy()
inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action) inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action)
new_pos = self.wrapped_env.get_xy() new_pos = self.wrapped_env.get_xy()

View File

@ -1,19 +1,11 @@
# Copyright 2018 The TensorFlow Authors All Rights Reserved. """
# Utilities for creating maze.
# Licensed under the Apache License, Version 2.0 (the "License"); Based on `models`_ and `rllab`_.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at .. _models: https://github.com/tensorflow/models/tree/master/research/efficient-hrl
# .. _rllab: https://github.com/rll/rllab
# http://www.apache.org/licenses/LICENSE-2.0 """
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Adapted from rllab maze_env_utils.py."""
import itertools as it import itertools as it
import math import math
from enum import Enum from enum import Enum

View File

@ -1,3 +1,6 @@
"""Maze tasks that are defined by their map, termination condition, and goals.
"""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, NamedTuple, Type from typing import Dict, List, NamedTuple, Type

View File

@ -1,19 +1,10 @@
# Copyright 2018 The TensorFlow Authors All Rights Reserved. """
# A ball-like robot as an explorer in the maze.
# Licensed under the Apache License, Version 2.0 (the "License"); Based on `models`_ and `rllab`_.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrapper for creating the ant environment in gym_mujoco.""" .. _models: https://github.com/tensorflow/models/tree/master/research/efficient-hrl
.. _rllab: https://github.com/rll/rllab
"""
import math import math
from typing import Optional, Tuple from typing import Optional, Tuple
@ -27,6 +18,7 @@ from mujoco_maze.agent_model import AgentModel
class PointEnv(AgentModel): class PointEnv(AgentModel):
FILE: str = "point.xml" FILE: str = "point.xml"
ORI_IND: int = 2 ORI_IND: int = 2
MANUAL_COLLISION: bool = True
VELOCITY_LIMITS: float = 10.0 VELOCITY_LIMITS: float = 10.0

View File

@ -1,22 +0,0 @@
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from mujoco_maze.maze_env import MazeEnv
from mujoco_maze.point import PointEnv
class PointMazeEnv(MazeEnv):
MODEL_CLASS = PointEnv
MANUAL_COLLISION = True