Merge pull request #1 from kngwyu/customizable

Introduce MazeTask for customizability
This commit is contained in:
Yuji Kanagawa 2020-08-04 14:32:25 +09:00 committed by GitHub
commit 1a8d5eb3bb
16 changed files with 737 additions and 649 deletions

View File

@ -5,23 +5,32 @@ Some maze environments for reinforcement learning (RL) using [mujoco-py] and
Thankfully, this project is based on the code from [rllab] and [tensorflow/models][models].
## Implemented Environments
- Distance based rewards
- AntMaze-v0
- AntPush-v0
- AntFall-v0
- PointMaze-v0
- PointPush-v0
- PointFall-v0
- Goal rewards + step penalty
- AntMaze-v1
- AntPush-v1
- AntFall-v1
- PointMaze-v1
- PointPush-v1
- PointFall-v1
## Environments
- PointUMaze/AntUMaze
![PointUMaze](./screenshots/PointUMaze.png)
- PointUMaze-v0/AntUMaze-v0 (Distance-based Reward)
- PointUMaze-v1/AntUMaze-v1 (Goal-based Reward, i.e., 1.0 or -ε)
- Point4Rooms/Ant4Rooms
![Point4Rooms](./screenshots/Point4Rooms.png)
- Point4Rooms-v0/Ant4Rooms-v0 (Distance-based Reward)
- Point4Rooms-v1/Ant4Rooms-v1 (Goal-based Reward)
- Point4Rooms-v2/Ant4Rooms-v2 (Multiple Goals (0.5 pt or 1.0 pt))
- PointPush/AntPush
![PointPush](./screenshots/PointPush.png)
- PointPush-v0/AntPush-v0 (Distance-based Reward)
- PointPush-v1/AntPush-v1 (Goal-based Reward)
- PointFall/AntFall
![PointFall](./screenshots/PointFall.png)
- PointFall-v0/AntFall-v0 (Distance-based Reward)
- PointFall-v1/AntFall-v1 (Goal-based Reward)
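A minimal usage sketch (the environment IDs above are registered when `mujoco_maze` is imported; the random-action loop is just for illustration):

```python
import gym

import mujoco_maze  # noqa: F401  # Importing this package registers the maze envs.

env = gym.make("PointUMaze-v0")
obs = env.reset()
done = False
while not done:
    # The returned reward is the scaled inner (control) reward plus the task reward.
    obs, reward, done, info = env.step(env.action_space.sample())
```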
## License
This project is licensed under the Apache License, Version 2.0.

View File

@ -1,47 +1,46 @@
"""
Mujoco Maze
----------
A maze environment using mujoco that supports custom tasks and robots.
"""
import gym
MAZE_IDS = ["Maze", "Push", "Fall"] # TODO: Block, BlockMaze
from mujoco_maze.ant import AntEnv
from mujoco_maze.maze_task import TaskRegistry
from mujoco_maze.point import PointEnv
for maze_id in TaskRegistry.keys():
for i, task_cls in enumerate(TaskRegistry.tasks(maze_id)):
gym.envs.register(
id=f"Ant{maze_id}-v{i}",
entry_point="mujoco_maze.maze_env:MazeEnv",
kwargs=dict(
model_cls=AntEnv,
maze_task=task_cls,
maze_size_scaling=task_cls.MAZE_SIZE_SCALING.ant,
inner_reward_scaling=task_cls.INNER_REWARD_SCALING,
),
max_episode_steps=1000,
reward_threshold=task_cls.REWARD_THRESHOLD,
)
def _get_kwargs(maze_id: str) -> dict:
return {
"maze_id": maze_id,
"observe_blocks": maze_id in ["Block", "BlockMaze"],
"put_spin_near_agent": maze_id in ["Block", "BlockMaze"],
}
for maze_id in MAZE_IDS:
gym.envs.register(
id="Ant{}-v0".format(maze_id),
entry_point="mujoco_maze.ant_maze_env:AntMazeEnv",
kwargs=dict(maze_size_scaling=8.0, **_get_kwargs(maze_id)),
max_episode_steps=1000,
reward_threshold=-1000,
)
gym.envs.register(
id="Ant{}-v1".format(maze_id),
entry_point="mujoco_maze.ant_maze_env:AntMazeEnv",
kwargs=dict(maze_size_scaling=8.0, **_get_kwargs(maze_id)),
max_episode_steps=1000,
reward_threshold=0.9,
)
for maze_id in MAZE_IDS:
gym.envs.register(
id="Point{}-v0".format(maze_id),
entry_point="mujoco_maze.point_maze_env:PointMazeEnv",
kwargs=_get_kwargs(maze_id),
max_episode_steps=1000,
reward_threshold=-1000,
)
gym.envs.register(
id="Point{}-v1".format(maze_id),
entry_point="mujoco_maze.point_maze_env:PointMazeEnv",
kwargs=dict(**_get_kwargs(maze_id), dense_reward=False),
max_episode_steps=1000,
reward_threshold=0.9,
)
for maze_id in TaskRegistry.keys():
for i, task_cls in enumerate(TaskRegistry.tasks(maze_id)):
gym.envs.register(
id=f"Point{maze_id}-v{i}",
entry_point="mujoco_maze.maze_env:MazeEnv",
kwargs=dict(
model_cls=PointEnv,
maze_task=task_cls,
maze_size_scaling=task_cls.MAZE_SIZE_SCALING.point,
inner_reward_scaling=task_cls.INNER_REWARD_SCALING,
),
max_episode_steps=1000,
reward_threshold=task_cls.REWARD_THRESHOLD,
)
__version__ = "0.1.0"
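As a rough illustration of what the two loops above register (a sketch; it assumes only the TaskRegistry keys and task lists shown in maze_task.py below):

import gym

import mujoco_maze  # noqa: F401

# One id per (registry key, task class) pair, e.g. PointUMaze-v0 or Point4Rooms-v2.
for maze_id in mujoco_maze.TaskRegistry.keys():
    for i in range(len(mujoco_maze.TaskRegistry.tasks(maze_id))):
        spec_id = f"Point{maze_id}-v{i}"
        print(spec_id, gym.spec(spec_id).reward_threshold)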

View File

@ -1,29 +1,22 @@
"""Common API definition for Ant and Point.
"""Common APIs for defining mujoco robot.
"""
from abc import ABC, abstractmethod
from gym.envs.mujoco.mujoco_env import MujocoEnv
from gym.utils import EzPickle
from mujoco_py import MjSimState
import numpy as np
class AgentModel(ABC, MujocoEnv, EzPickle):
FILE: str
ORI_IND: int
MANUAL_COLLISION: bool
RADIUS: float
def __init__(self, file_path: str, frame_skip: int) -> None:
MujocoEnv.__init__(self, file_path, frame_skip)
EzPickle.__init__(self)
def set_state_without_forward(self, qpos, qvel):
assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,)
old_state = self.sim.get_state()
new_state = MjSimState(
old_state.time, qpos, qvel, old_state.act, old_state.udd_state
)
self.sim.set_state(new_state)
self.sim.forward()
@abstractmethod
def _get_obs(self) -> np.ndarray:
"""Returns the observation from the model.

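For orientation, a minimal robot against this API might look like the sketch below (the ball robot, its ball.xml asset, and all numbers are hypothetical; only the AgentModel attributes and the abstract _get_obs come from this file):

import numpy as np

from mujoco_maze.agent_model import AgentModel


class BallEnv(AgentModel):
    FILE: str = "ball.xml"  # Hypothetical asset under mujoco_maze/assets.
    ORI_IND: int = 2  # Index of the orientation entry in qpos.
    MANUAL_COLLISION: bool = True  # Ask MazeEnv to resolve wall collisions for us.
    RADIUS: float = 0.3  # Used by the collision detector to inflate walls.

    def __init__(self, file_path=None):
        super().__init__(file_path, 1)

    def _get_obs(self) -> np.ndarray:
        return np.concatenate([self.sim.data.qpos.flat, self.sim.data.qvel.flat])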
View File

@ -1,25 +1,28 @@
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
A four-legged robot as an explorer in the maze.
Based on `models`_ and `gym`_ (both ant and ant-v3).
.. _models: https://github.com/tensorflow/models/tree/master/research/efficient-hrl
.. _gym: https://github.com/openai/gym
"""
"""Wrapper for creating the ant environment in gym_mujoco."""
import math
from typing import Callable, Optional, Tuple
import numpy as np
from mujoco_maze.agent_model import AgentModel
ForwardRewardFn = Callable[[np.ndarray], float]
def forward_reward_vabs(xy_velocity: np.ndarray) -> float:
return np.sum(np.abs(xy_velocity))
def forward_reward_vnorm(xy_velocity: np.ndarray) -> float:
return np.linalg.norm(xy_velocity)
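As a tiny numeric check of the two variants (plain numpy, runnable standalone):

import numpy as np

v = np.array([3.0, 4.0])  # An xy velocity.
assert np.sum(np.abs(v)) == 7.0   # forward_reward_vabs: L1 norm.
assert np.linalg.norm(v) == 5.0   # forward_reward_vnorm: L2 norm.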
def q_inv(a):
return [a[0], -a[1], -a[2], -a[3]]
@ -34,79 +37,49 @@ def q_mult(a, b): # multiply two quaternion
class AntEnv(AgentModel):
FILE = "ant.xml"
ORI_IND = 3
FILE: str = "ant.xml"
ORI_IND: int = 3
MANUAL_COLLISION: bool = False
RADIUS: float = 0.2
def __init__(
self,
file_path=None,
expose_all_qpos=True,
expose_body_coms=None,
expose_body_comvels=None,
):
self._expose_all_qpos = expose_all_qpos
self._expose_body_coms = expose_body_coms
self._expose_body_comvels = expose_body_comvels
self._body_com_indices = {}
self._body_comvel_indices = {}
file_path: Optional[str] = None,
ctrl_cost_weight: float = 0.0001,
forward_reward_fn: ForwardRewardFn = forward_reward_vnorm,
) -> None:
self._ctrl_cost_weight = ctrl_cost_weight
self._forward_reward_fn = forward_reward_fn
super().__init__(file_path, 5)
def _step(self, a):
return self.step(a)
def _forward_reward(self, xy_pos_before: np.ndarray) -> float:
xy_pos_after = self.sim.data.qpos[:2].copy()
xy_velocity = (xy_pos_after - xy_pos_before) / self.dt
return self._forward_reward_fn(xy_velocity)
def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, dict]:
xy_pos_before = self.sim.data.qpos[:2].copy()
self.do_simulation(action, self.frame_skip)
forward_reward = self._forward_reward(xy_pos_before)
ctrl_cost = self._ctrl_cost_weight * np.square(action).sum()
def step(self, a):
xposbefore = self.get_body_com("torso")[0]
self.do_simulation(a, self.frame_skip)
xposafter = self.get_body_com("torso")[0]
forward_reward = (xposafter - xposbefore) / self.dt
ctrl_cost = 0.5 * np.square(a).sum()
survive_reward = 1.0
reward = forward_reward - ctrl_cost + survive_reward
_ = self.state_vector()
done = False
ob = self._get_obs()
return (
ob,
reward,
done,
dict(
reward_forward=forward_reward,
reward_ctrl=-ctrl_cost,
reward_survive=survive_reward,
),
forward_reward - ctrl_cost,
False,
dict(reward_forward=forward_reward, reward_ctrl=-ctrl_cost,),
)
def _get_obs(self):
# No cfrc observation
if self._expose_all_qpos:
obs = np.concatenate(
[
self.sim.data.qpos.flat[:15], # Ensures only ant obs.
self.sim.data.qvel.flat[:14],
]
)
else:
obs = np.concatenate(
[self.sim.data.qpos.flat[2:15], self.sim.data.qvel.flat[:14],]
)
if self._expose_body_coms is not None:
for name in self._expose_body_coms:
com = self.get_body_com(name)
if name not in self._body_com_indices:
indices = range(len(obs), len(obs) + len(com))
self._body_com_indices[name] = indices
obs = np.concatenate([obs, com])
if self._expose_body_comvels is not None:
for name in self._expose_body_comvels:
comvel = self.get_body_comvel(name)
if name not in self._body_comvel_indices:
indices = range(len(obs), len(obs) + len(comvel))
self._body_comvel_indices[name] = indices
obs = np.concatenate([obs, comvel])
return obs
return np.concatenate(
[
self.sim.data.qpos.flat[:15], # Ensures only ant obs.
self.sim.data.qvel.flat[:14],
]
)
def reset_model(self):
qpos = self.init_qpos + self.np_random.uniform(
@ -120,9 +93,6 @@ class AntEnv(AgentModel):
self.set_state(qpos, qvel)
return self._get_obs()
def viewer_setup(self):
self.viewer.cam.distance = self.model.stat.extent * 0.5
def get_ori(self):
ori = [0, 1, 0, 0]
ori_ind = self.ORI_IND
@ -132,12 +102,9 @@ class AntEnv(AgentModel):
return ori
def set_xy(self, xy):
qpos = np.copy(self.sim.data.qpos)
qpos[0] = xy[0]
qpos[1] = xy[1]
qvel = self.sim.data.qvel
self.set_state_without_forward(qpos, qvel)
qpos = self.sim.data.qpos.copy()
qpos[:2] = xy
self.set_state(qpos, self.sim.data.qvel)
def get_xy(self):
return np.copy(self.sim.data.qpos[:2])

View File

@ -1,21 +0,0 @@
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from mujoco_maze.maze_env import MazeEnv
from mujoco_maze.ant import AntEnv
class AntMazeEnv(MazeEnv):
MODEL_CLASS = AntEnv

View File

@ -1,86 +1,60 @@
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Mujoco Maze environment.
Based on `models`_ and `rllab`_.
.. _models: https://github.com/tensorflow/models/tree/master/research/efficient-hrl
.. _rllab: https://github.com/rll/rllab
"""
"""Adapted from rllab maze_env.py."""
import itertools as it
import math
import os
import tempfile
import xml.etree.ElementTree as ET
from typing import Callable, List, Tuple, Type, Union
import gym
import numpy as np
from mujoco_maze import maze_env_utils, maze_task
from mujoco_maze.agent_model import AgentModel
# Directory that contains mujoco xml files.
MODEL_DIR = os.path.dirname(os.path.abspath(__file__)) + "/assets"
class MazeEnv(gym.Env):
MODEL_CLASS: Type[AgentModel] = AgentModel
MANUAL_COLLISION: bool = False
# For preventing the point from going through the wall
SIZE_EPS = 0.0001
def __init__(
self,
maze_id=None,
n_bins=0,
sensor_range=3.0,
sensor_span=2 * math.pi,
observe_blocks=False,
put_spin_near_agent=False,
top_down_view=False,
dense_reward=True,
model_cls: Type[AgentModel],
maze_task: Type[maze_task.MazeTask] = maze_task.MazeTask,
top_down_view: bool = False,
maze_height: float = 0.5,
maze_size_scaling: float = 4.0,
goal_sampler: Union[str, np.ndarray, Callable[[], np.ndarray]] = "default",
inner_reward_scaling: float = 1.0,
restitution_coef: float = 0.8,
*args,
**kwargs,
) -> None:
self._maze_id = maze_id
self._task = maze_task(maze_size_scaling)
xml_path = os.path.join(MODEL_DIR, self.MODEL_CLASS.FILE)
xml_path = os.path.join(MODEL_DIR, model_cls.FILE)
tree = ET.parse(xml_path)
worldbody = tree.find(".//worldbody")
self._maze_height = height = maze_height
self._maze_size_scaling = size_scaling = maze_size_scaling
self._inner_reward_scaling = inner_reward_scaling
self.t = 0 # time steps
self._n_bins = n_bins
self._sensor_range = sensor_range * size_scaling
self._sensor_span = sensor_span
self._observe_blocks = observe_blocks
self._put_spin_near_agent = put_spin_near_agent
self._observe_blocks = self._task.OBSERVE_BLOCKS
self._put_spin_near_agent = self._task.PUT_SPIN_NEAR_AGENT
self._top_down_view = top_down_view
self._collision_coef = 0.1
self._restitution_coef = restitution_coef
self._maze_structure = structure = maze_env_utils.construct_maze(
maze_id=self._maze_id
)
self._maze_structure = structure = self._task.create_maze()
# Elevate the maze to allow for falling.
self.elevated = any(maze_env_utils.MazeCell.CHASM in row for row in structure)
# Are there any movable blocks?
self.blocks = any(
any(r.can_move() for r in row) for row in structure
)
self.blocks = any(any(r.can_move() for r in row) for row in structure)
torso_x, torso_y = self._find_robot()
self._init_torso_x = torso_x
@ -89,8 +63,8 @@ class MazeEnv(gym.Env):
(x - torso_x, y - torso_y) for x, y in self._find_all_robots()
]
self._collision = maze_env_utils.Collision(
structure, size_scaling, torso_x, torso_y,
self._collision = maze_env_utils.CollisionDetector(
structure, size_scaling, torso_x, torso_y, model_cls.RADIUS,
)
self._xy_to_rowcol = lambda x, y: (
@ -105,7 +79,7 @@ class MazeEnv(gym.Env):
# Increase initial z-pos of ant.
height_offset = height * size_scaling
torso = tree.find(".//body[@name='torso']")
torso.set("pos", "0 0 %.2f" % (0.75 + height_offset))
torso.set("pos", f"0 0 {0.75 + height_offset:.2f}")
if self.blocks:
# If there are movable blocks, change simulation settings to perform
# better contact detection.
@ -117,13 +91,13 @@ class MazeEnv(gym.Env):
for j in range(len(structure[0])):
struct = structure[i][j]
if struct.is_robot() and self._put_spin_near_agent:
struct = maze_env_utils.Move.SpinXY
struct = maze_env_utils.MazeCell.SpinXY
if self.elevated and not struct.is_chasm():
# Create elevated platform.
x = j * size_scaling - torso_x
y = i * size_scaling - torso_y
h = height / 2 * size_scaling
size = 0.5 * size_scaling + self.SIZE_EPS
size = 0.5 * size_scaling
ET.SubElement(
worldbody,
"geom",
@ -142,7 +116,7 @@ class MazeEnv(gym.Env):
x = j * size_scaling - torso_x
y = i * size_scaling - torso_y
h = height / 2 * size_scaling
size = 0.5 * size_scaling + self.SIZE_EPS
size = 0.5 * size_scaling
ET.SubElement(
worldbody,
"geom",
@ -172,7 +146,7 @@ class MazeEnv(gym.Env):
)
y = i * size_scaling - torso_y
h = height / 2 * size_scaling * height_shrink
size = 0.5 * size_scaling * shrink + self.SIZE_EPS
size = 0.5 * size_scaling * shrink
movable_body = ET.SubElement(
worldbody,
"body",
@ -253,37 +227,58 @@ class MazeEnv(gym.Env):
if "name" not in geom.attrib:
raise Exception("Every geom of the torso must have a name defined")
# Set goals
asset = tree.find(".//asset")
for i, goal in enumerate(self._task.goals):
ET.SubElement(asset, "material", name=f"goal{i}", rgba=goal.rbga_str())
z = goal.pos[2] if goal.dim >= 3 else 0.0
ET.SubElement(
worldbody,
"site",
name=f"goal_site{i}",
pos=f"{goal.pos[0]} {goal.pos[1]} {z}",
size=f"{maze_size_scaling * 0.1}",
material=f"goal{i}",
)
_, file_path = tempfile.mkstemp(text=True, suffix=".xml")
tree.write(file_path)
self.wrapped_env = self.MODEL_CLASS(*args, file_path=file_path, **kwargs)
self.world_tree = tree
self.wrapped_env = model_cls(*args, file_path=file_path, **kwargs)
self.observation_space = self._get_obs_space()
self._debug = False
# Set reward function
self._reward_fn = _reward_fn(maze_id, dense_reward)
# Set goal sampler
if isinstance(goal_sampler, str):
if goal_sampler == "random":
self._goal_sampler = lambda: np.random.uniform((-4, -4), (20, 20))
elif goal_sampler == "default":
default_goal = _default_goal(maze_id, size_scaling)
self._goal_sampler = lambda: default_goal
else:
raise NotImplementedError(f"Unknown goal_sampler: {goal_sampler}")
elif isinstance(goal_sampler, np.ndarray):
self._goal_sampler = lambda: goal_sampler
elif callable(goal_sampler):
self._goal_sampler = goal_sampler
else:
raise ValueError(f"Invalid goal_sampler: {goal_sampler}")
self.goal = self._goal_sampler()
# Set goal function
self._goal_fn = _goal_fn(maze_id)
def get_ori(self):
def get_ori(self) -> float:
return self.wrapped_env.get_ori()
def get_top_down_view(self):
def _get_obs_space(self) -> gym.spaces.Box:
shape = self._get_obs().shape
high = np.inf * np.ones(shape, dtype=np.float32)
low = -high
# Set velocity limits
wrapped_obs_space = self.wrapped_env.observation_space
high[: wrapped_obs_space.shape[0]] = wrapped_obs_space.high
low[: wrapped_obs_space.shape[0]] = wrapped_obs_space.low
# Set coordinate limits
low[0], high[0], low[1], high[1] = self._xy_limits()
# Set orientation limits
return gym.spaces.Box(low, high)
def _xy_limits(self) -> Tuple[float, float, float, float]:
xmin, ymin, xmax, ymax = 100, 100, -100, -100
structure = self._maze_structure
for i, j in it.product(range(len(structure)), range(len(structure[0]))):
if structure[i][j].is_block():
continue
xmin, xmax = min(xmin, j), max(xmax, j)
ymin, ymax = min(ymin, i), max(ymax, i)
x0, y0 = self._init_torso_x, self._init_torso_y
scaling = self._maze_size_scaling
xmin, xmax = (xmin - 0.5) * scaling - x0, (xmax + 0.5) * scaling - x0
ymin, ymax = (ymin - 0.5) * scaling - y0, (ymax + 0.5) * scaling - y0
return xmin, xmax, ymin, ymax
def get_top_down_view(self) -> np.ndarray:
self._view = np.zeros_like(self._view)
def valid(row, col):
@ -373,98 +368,7 @@ class MazeEnv(gym.Env):
return self._view
def get_range_sensor_obs(self):
"""Returns egocentric range sensor observations of maze."""
robot_x, robot_y, robot_z = self.wrapped_env.get_body_com("torso")[:3]
ori = self.get_ori()
structure = self._maze_structure
size_scaling = self._maze_size_scaling
height = self._maze_height
segments = []
# Get line segments (corresponding to outer boundary) of each immovable
# block or drop-off.
for i in range(len(structure)):
for j in range(len(structure[0])):
if structure[i][j].is_wall_or_chasm(): # There's a wall or drop-off.
cx = j * size_scaling - self._init_torso_x
cy = i * size_scaling - self._init_torso_y
x1 = cx - 0.5 * size_scaling
x2 = cx + 0.5 * size_scaling
y1 = cy - 0.5 * size_scaling
y2 = cy + 0.5 * size_scaling
struct_segments = [
((x1, y1), (x2, y1)),
((x2, y1), (x2, y2)),
((x2, y2), (x1, y2)),
((x1, y2), (x1, y1)),
]
for seg in struct_segments:
segments.append(dict(segment=seg, type=structure[i][j],))
# Get line segments (corresponding to outer boundary) of each movable
# block within the agent's z-view.
for block_name, block_type in self.movable_blocks:
block_x, block_y, block_z = self.wrapped_env.get_body_com(block_name)[:3]
if (
block_z + height * size_scaling / 2 >= robot_z
and robot_z >= block_z - height * size_scaling / 2
): # Block in view.
x1 = block_x - 0.5 * size_scaling
x2 = block_x + 0.5 * size_scaling
y1 = block_y - 0.5 * size_scaling
y2 = block_y + 0.5 * size_scaling
struct_segments = [
((x1, y1), (x2, y1)),
((x2, y1), (x2, y2)),
((x2, y2), (x1, y2)),
((x1, y2), (x1, y1)),
]
for seg in struct_segments:
segments.append(dict(segment=seg, type=block_type))
sensor_readings = np.zeros((self._n_bins, 3)) # 3 for wall, drop-off, block
for ray_idx in range(self._n_bins):
ray_ori = (
ori
- self._sensor_span * 0.5
+ (2 * ray_idx + 1.0) / (2 * self._n_bins) * self._sensor_span
)
ray_segments = []
# Get all segments that intersect with ray.
for seg in segments:
p = maze_env_utils.ray_segment_intersect(
ray=((robot_x, robot_y), ray_ori), segment=seg["segment"]
)
if p is not None:
ray_segments.append(
dict(
segment=seg["segment"],
type=seg["type"],
ray_ori=ray_ori,
distance=maze_env_utils.point_distance(
p, (robot_x, robot_y)
),
)
)
if len(ray_segments) > 0:
# Find out which segment is intersected first.
first_seg = sorted(ray_segments, key=lambda x: x["distance"])[0]
seg_type = first_seg["type"]
idx = None
if seg_type == 1:
idx = 0 # Wall
elif seg_type == -1:
idx = 1 # Drop-off
elif seg_type.can_move():
idx = 2  # Block
sr = self._sensor_range
if first_seg["distance"] <= sr:
sensor_readings[ray_idx][idx] = (sr - first_seg["distance"]) / sr
return sensor_readings
def _get_obs(self):
def _get_obs(self) -> np.ndarray:
wrapped_obs = self.wrapped_env._get_obs()
if self._top_down_view:
view = [self.get_top_down_view().flat]
@ -479,21 +383,25 @@ class MazeEnv(gym.Env):
[wrapped_obs[:3]] + additional_obs + [wrapped_obs[3:]]
)
range_sensor_obs = self.get_range_sensor_obs()
return np.concatenate(
[wrapped_obs, range_sensor_obs.flat] + view + [[self.t * 0.001]]
)
return np.concatenate([wrapped_obs, *view, np.array([self.t * 0.001])])
def reset(self):
def reset(self) -> np.ndarray:
self.t = 0
self.wrapped_env.reset()
# Sample a new goal
self.goal = self._goal_sampler()
# Samples a new goal
if self._task.sample_goals():
self.set_marker()
# Samples a new start position
if len(self._init_positions) > 1:
xy = np.random.choice(self._init_positions)
self.wrapped_env.set_xy(xy)
return self._get_obs()
def set_marker(self) -> None:
for i, goal in enumerate(self._task.goals):
idx = self.model.site_name2id(f"goal_site{i}")
self.data.site_xpos[idx][: len(goal.pos)] = goal.pos
@property
def viewer(self):
return self.wrapped_env.viewer
@ -501,18 +409,11 @@ class MazeEnv(gym.Env):
def render(self, *args, **kwargs):
return self.wrapped_env.render(*args, **kwargs)
@property
def observation_space(self):
shape = self._get_obs().shape
high = np.inf * np.ones(shape)
low = -high
return gym.spaces.Box(low, high)
@property
def action_space(self):
return self.wrapped_env.action_space
def _find_robot(self):
def _find_robot(self) -> Tuple[float, float]:
structure = self._maze_structure
size_scaling = self._maze_size_scaling
for i, j in it.product(range(len(structure)), range(len(structure[0]))):
@ -520,7 +421,7 @@ class MazeEnv(gym.Env):
return j * size_scaling, i * size_scaling
raise ValueError("No robot in maze specification.")
def _find_all_robots(self):
def _find_all_robots(self) -> List[Tuple[float, float]]:
structure = self._maze_structure
size_scaling = self._maze_size_scaling
coords = []
@ -529,62 +430,26 @@ class MazeEnv(gym.Env):
coords.append((j * size_scaling, i * size_scaling))
return coords
def step(self, action):
def step(self, action) -> Tuple[np.ndarray, float, bool, dict]:
self.t += 1
if self.MANUAL_COLLISION:
if self.wrapped_env.MANUAL_COLLISION:
old_pos = self.wrapped_env.get_xy()
inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action)
new_pos = self.wrapped_env.get_xy()
if self._collision.is_in(old_pos, new_pos):
self.wrapped_env.set_xy(old_pos)
# Check whether the new position is inside a wall
collision = self._collision.detect(old_pos, new_pos)
if collision is not None:
pos = collision.point + self._restitution_coef * collision.rest()
if self._collision.detect(old_pos, pos) is not None:
# If the rebound position is also inside a wall, give up and revert
self.wrapped_env.set_xy(old_pos)
else:
self.wrapped_env.set_xy(pos)
else:
inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action)
next_obs = self._get_obs()
outer_reward = self._reward_fn(next_obs, self.goal)
done = self._goal_fn(next_obs, self.goal)
inner_reward = self._inner_reward_scaling * inner_reward
outer_reward = self._task.reward(next_obs)
done = self._task.termination(next_obs)
info["position"] = self.wrapped_env.get_xy()
return next_obs, inner_reward + outer_reward, done, info
def _goal_fn(maze_id: str) -> callable:
if maze_id in ["Maze", "Push", "BlockMaze"]:
return lambda obs, goal: np.linalg.norm(obs[:2] - goal) <= 0.6
elif maze_id == "Fall":
return lambda obs, goal: np.linalg.norm(obs[:3] - goal) <= 0.6
else:
raise NotImplementedError(f"Unknown maze id: {maze_id}")
def _reward_fn(maze_id: str, dense: bool) -> callable:
if dense:
if maze_id in ["Maze", "Push", "BlockMaze"]:
return lambda obs, goal: -np.sum(np.square(obs[:2] - goal)) ** 0.5
elif maze_id == "Fall":
return lambda obs, goal: -np.sum(np.square(obs[:3] - goal)) ** 0.5
else:
raise NotImplementedError(f"Unknown maze id: {maze_id}")
else:
if maze_id in ["Maze", "Push", "BlockMaze"]:
return (
lambda obs, goal: 1.0
if np.linalg.norm(obs[:2] - goal) <= 0.6
else -0.0001
)
elif maze_id == "Fall":
return (
lambda obs, goal: 1.0
if np.linalg.norm(obs[:3] - goal) <= 0.6
else -0.0001
)
else:
raise NotImplementedError(f"Unknown maze id: {maze_id}")
def _default_goal(maze_id: str, scale: float) -> np.ndarray:
if maze_id == "Maze" or maze_id == "BlockMaze":
return np.array([0.0, 2.0 * scale])
elif maze_id == "Push":
return np.array([0.0, 2.375 * scale])
elif maze_id == "Fall":
return np.array([0.0, 3.375 * scale, 4.5])
else:
raise NotImplementedError(f"Unknown maze id: {maze_id}")

View File

@ -1,24 +1,20 @@
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Utilities for creating maze.
Based on `models`_ and `rllab`_.
.. _models: https://github.com/tensorflow/models/tree/master/research/efficient-hrl
.. _rllab: https://github.com/rll/rllab
"""
"""Adapted from rllab maze_env_utils.py."""
import itertools as it
import math
from enum import Enum
from typing import Any, List, NamedTuple, Optional, Sequence, Tuple, Union
import numpy as np
Self = Any
Point = np.complex
class MazeCell(Enum):
# Robot: Start position
@ -43,6 +39,9 @@ class MazeCell(Enum):
def is_chasm(self) -> bool:
return self == self.CHASM
def is_empty(self) -> bool:
return self == self.ROBOT or self == self.EMPTY
def is_robot(self) -> bool:
return self == self.ROBOT
@ -77,156 +76,123 @@ class MazeCell(Enum):
return self.can_move_x() or self.can_move_y() or self.can_move_z()
def construct_maze(maze_id="Maze"):
E, B, C, R = MazeCell.EMPTY, MazeCell.BLOCK, MazeCell.CHASM, MazeCell.ROBOT
if maze_id == "Maze":
structure = [
[B, B, B, B, B],
[B, R, E, E, B],
[B, B, B, E, B],
[B, E, E, E, B],
[B, B, B, B, B],
]
elif maze_id == "Push":
structure = [
[B, B, B, B, B],
[B, E, R, B, B],
[B, E, MazeCell.XY, E, B],
[B, B, E, B, B],
[B, B, B, B, B],
]
elif maze_id == "Fall":
structure = [
[B, B, B, B],
[B, R, E, B],
[B, E, MazeCell.YZ, B],
[B, C, C, B],
[B, E, E, B],
[B, B, B, B],
]
elif maze_id == "Block":
structure = [
[B, B, B, B, B],
[B, R, E, E, B],
[B, E, E, E, B],
[B, E, E, E, B],
[B, B, B, B, B],
]
elif maze_id == "BlockMaze":
structure = [
[B, B, B, B],
[B, R, E, B],
[B, B, E, B],
[B, E, E, B],
[B, B, B, B],
]
else:
raise NotImplementedError("The provided MazeId %s is not recognized" % maze_id)
return structure
class Line:
def __init__(
self, p1: Union[Sequence[float], Point], p2: Union[Sequence[float], Point],
) -> None:
self.p1 = p1 if isinstance(p1, Point) else np.complex(*p1)
self.p2 = p2 if isinstance(p2, Point) else np.complex(*p2)
self.v1 = self.p2 - self.p1
self.conj_v1 = np.conjugate(self.v1)
if np.absolute(self.v1) <= 1e-8:
raise ValueError(f"p1({p1}) and p2({p2}) are too close")
def _intersect(self, other: Self) -> bool:
v2 = other.p1 - self.p1
v3 = other.p2 - self.p1
return (self.conj_v1 * v2).imag * (self.conj_v1 * v3).imag <= 0.0
def _projection(self, p: Point) -> Point:
nv1 = -self.v1
nv1_norm = np.absolute(nv1) ** 2
scale = np.real(np.conjugate(p - self.p1) * nv1) / nv1_norm
return self.p1 + nv1 * scale
def reflection(self, p: Point) -> Point:
return p + 2.0 * (self._projection(p) - p)
def distance(self, p: Point) -> float:
return np.absolute(p - self._projection(p))
def intersect(self, other: Self) -> Optional[Point]:
if self._intersect(other) and other._intersect(self):
return self._cross_point(other)
else:
return None
def _cross_point(self, other: Self) -> Optional[Point]:
v2 = other.p2 - other.p1
v3 = self.p2 - other.p1
a, b = (self.conj_v1 * v2).imag, (self.conj_v1 * v3).imag
return other.p1 + b / a * v2
def __repr__(self) -> str:
x1, y1 = self.p1.real, self.p1.imag
x2, y2 = self.p2.real, self.p2.imag
return f"Line(({x1}, {y1}) -> ({x2}, {y2}))"
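A small numeric sketch of the complex-number helpers above (projection, reflection, and distance against the segment's supporting line):

import numpy as np

from mujoco_maze.maze_env_utils import Line

line = Line((0.0, 0.0), (4.0, 0.0))  # A segment on the x-axis.
p = np.complex(1.0, 3.0)
assert abs(line.distance(p) - 3.0) <= 1e-8  # Distance from (1, 3) to the x-axis.
assert abs(line.reflection(p) - np.complex(1.0, -3.0)) <= 1e-8  # Mirror image.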
class Collision:
def __init__(self, point: Point, reflection: Point) -> None:
self._point = point
self._reflection = reflection
@property
def point(self) -> np.ndarray:
return np.array([self._point.real, self._point.imag])
def rest(self) -> np.ndarray:
p = self._reflection - self._point
return np.array([p.real, p.imag])
class CollisionDetector:
"""For manual collision detection.
"""
ARROUND = np.array([[-1, 0], [1, 0], [0, -1], [0, 1]])
OFFSET = {False: 0.48, True: 0.51}
EPS: float = 0.05
NEIGHBORS: List[Tuple[int, int]] = [[0, -1], [-1, 0], [0, 1], [1, 0]]
def __init__(
self, structure: list, size_scaling: float, torso_x: float, torso_y: float,
self,
structure: list,
size_scaling: float,
torso_x: float,
torso_y: float,
radius: float,
) -> None:
h, w = len(structure), len(structure[0])
self.objects = []
self.lines = []
def is_block(pos) -> bool:
i, j = pos
if 0 <= i < h and 0 <= j < w:
return structure[i][j].is_block()
else:
return False
def offset(pos, index) -> float:
return self.OFFSET[is_block(pos + self.ARROUND[index])]
def is_empty(i, j) -> bool:
if 0 <= i < h and 0 <= j < w:
return structure[i][j].is_empty()
else:
return False
for i, j in it.product(range(len(structure)), range(len(structure[0]))):
if not structure[i][j].is_block():
continue
pos = np.array([i, j])
y_base = i * size_scaling - torso_y
min_y = y_base - size_scaling * offset(pos, 0)
max_y = y_base + size_scaling * offset(pos, 1)
x_base = j * size_scaling - torso_x
min_x = x_base - size_scaling * offset(pos, 2)
max_x = x_base + size_scaling * offset(pos, 3)
self.objects.append((min_y, max_y, min_x, max_x))
offset = size_scaling * 0.5 + radius
min_y, max_y = y_base - offset, y_base + offset
min_x, max_x = x_base - offset, x_base + offset
for dx, dy in self.NEIGHBORS:
if not is_empty(i + dy, j + dx):
continue
self.lines.append(
Line(
(max_x if dx == 1 else min_x, max_y if dy == 1 else min_y),
(min_x if dx == -1 else max_x, min_y if dy == -1 else max_y),
)
)
def is_in(self, old_pos, new_pos) -> bool:
# Heuristics to prevent the agent from going through the wall
for x, y in ((old_pos + new_pos) / 2, new_pos):
for min_y, max_y, min_x, max_x in self.objects:
if min_x <= x <= max_x and min_y <= y <= max_y:
return True
return False
def line_intersect(pt1, pt2, ptA, ptB):
"""
Taken from https://www.cs.hmc.edu/ACM/lectures/intersections.html
Returns the intersection of Line(pt1,pt2) and Line(ptA,ptB).
"""
DET_TOLERANCE = 0.00000001
# the first line is pt1 + r*(pt2-pt1)
# in component form:
x1, y1 = pt1
x2, y2 = pt2
dx1 = x2 - x1
dy1 = y2 - y1
# the second line is ptA + s*(ptB-ptA)
x, y = ptA
xB, yB = ptB
dx = xB - x
dy = yB - y
DET = -dx1 * dy + dy1 * dx
if math.fabs(DET) < DET_TOLERANCE:
return (0, 0, 0, 0, 0)
# now, the determinant should be OK
DETinv = 1.0 / DET
# find the scalar amount along the "self" segment
r = DETinv * (-dy * (x - x1) + dx * (y - y1))
# find the scalar amount along the input line
s = DETinv * (-dy1 * (x - x1) + dx1 * (y - y1))
# return the average of the two descriptions
xi = (x1 + r * dx1 + x + s * dx) / 2.0
yi = (y1 + r * dy1 + y + s * dy) / 2.0
return (xi, yi, 1, r, s)
def ray_segment_intersect(ray, segment):
"""
Check if the ray originated from (x, y) with direction theta intersect the line
segment (x1, y1) -- (x2, y2), and return the intersection point if there is one.
"""
(x, y), theta = ray
# (x1, y1), (x2, y2) = segment
pt1 = (x, y)
pt2 = (x + math.cos(theta), y + math.sin(theta))
xo, yo, valid, r, s = line_intersect(pt1, pt2, *segment)
if valid and r >= 0 and 0 <= s <= 1:
return (xo, yo)
return None
def point_distance(p1, p2):
x1, y1 = p1
x2, y2 = p2
return ((x1 - x2) ** 2 + (y1 - y2) ** 2) ** 0.5
def detect(self, old_pos: np.ndarray, new_pos: np.ndarray) -> Optional[Collision]:
move = Line(old_pos, new_pos)
# Check whether the trajectory crosses a wall
collisions = []
for line in self.lines:
intersection = line.intersect(move)
if intersection is not None:
reflection = line.reflection(move.p2)
collisions.append(Collision(intersection, reflection))
if len(collisions) == 0:
return None
col = collisions[0]
dist = np.absolute(col._point - move.p1)
for collision in collisions[1:]:
new_dist = np.absolute(collision._point - move.p1)
if new_dist < dist:
col, dist = collision, new_dist
return col
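A rough usage sketch mirroring how MazeEnv.step consumes this detector (the 3x3 maze, the scaling, the radius, and the 0.8 restitution coefficient are illustrative values):

import numpy as np

from mujoco_maze.maze_env_utils import CollisionDetector, MazeCell

E, B, R = MazeCell.EMPTY, MazeCell.BLOCK, MazeCell.ROBOT
structure = [
    [B, B, B],
    [B, R, B],
    [B, B, B],
]
# torso_x/torso_y put the robot cell at the origin; radius matches PointEnv.
detector = CollisionDetector(structure, 4.0, 4.0, 4.0, 0.4)
old_pos, new_pos = np.array([0.0, 0.0]), np.array([3.0, 0.0])
collision = detector.detect(old_pos, new_pos)
assert collision is not None  # The inflated wall face sits at x = 2.0 - 0.4 = 1.6.
# Rebound as in MazeEnv.step: stop at the wall, bounce back part of the leftover motion.
pos = collision.point + 0.8 * collision.rest()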

mujoco_maze/maze_task.py (new file, 255 lines)
View File

@ -0,0 +1,255 @@
"""Maze tasks that are defined by their map, termination condition, and goals.
"""
from abc import ABC, abstractmethod
from typing import Dict, List, NamedTuple, Type
import numpy as np
from mujoco_maze.maze_env_utils import MazeCell
class Rgb(NamedTuple):
red: float
green: float
blue: float
RED = Rgb(0.7, 0.1, 0.1)
GREEN = Rgb(0.1, 0.7, 0.1)
BLUE = Rgb(0.1, 0.1, 0.7)
class MazeGoal:
THRESHOLD: float = 0.6
def __init__(
self, pos: np.ndarray, reward_scale: float = 1.0, rgb: Rgb = RED
) -> None:
assert 0.0 <= reward_scale <= 1.0
self.pos = pos
self.dim = pos.shape[0]
self.reward_scale = reward_scale
self.rgb = rgb
def rbga_str(self) -> str:
r, g, b = self.rgb
return f"{r} {g} {b} 1"
def neighbor(self, obs: np.ndarray) -> bool:
return np.linalg.norm(obs[: self.dim] - self.pos) <= self.THRESHOLD
def euc_dist(self, obs: np.ndarray) -> float:
return np.sum(np.square(obs[: self.dim] - self.pos)) ** 0.5
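For intuition, a goal is hit when the first `dim` observation entries fall inside an L2 ball of radius THRESHOLD (a standalone check):

import numpy as np

from mujoco_maze.maze_task import MazeGoal

goal = MazeGoal(np.array([0.0, 8.0]))
assert goal.neighbor(np.array([0.5, 8.0, 0.3]))  # Only obs[:2] is compared; 0.5 <= 0.6.
assert abs(goal.euc_dist(np.array([3.0, 4.0])) - 5.0) <= 1e-8  # sqrt(3^2 + 4^2).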
class Scaling(NamedTuple):
ant: float
point: float
class MazeTask(ABC):
REWARD_THRESHOLD: float
MAZE_SIZE_SCALING: Scaling = Scaling(8.0, 4.0)
INNER_REWARD_SCALING: float = 0.01
OBSERVE_BLOCKS: bool = False
PUT_SPIN_NEAR_AGENT: bool = False
def __init__(self, scale: float) -> None:
self.goals = []
self.scale = scale
def sample_goals(self) -> bool:
return False
def termination(self, obs: np.ndarray) -> bool:
for goal in self.goals:
if goal.neighbor(obs):
return True
return False
@abstractmethod
def reward(self, obs: np.ndarray) -> float:
pass
@staticmethod
@abstractmethod
def create_maze() -> List[List[MazeCell]]:
pass
class DistRewardMixIn:
REWARD_THRESHOLD: float = -1000.0
goals: List[MazeGoal]
scale: float
def reward(self, obs: np.ndarray) -> float:
return -self.goals[0].euc_dist(obs) / self.scale
class GoalRewardUMaze(MazeTask):
REWARD_THRESHOLD: float = 0.9
def __init__(self, scale: float) -> None:
super().__init__(scale)
self.goals = [MazeGoal(np.array([0.0, 2.0 * scale]))]
def reward(self, obs: np.ndarray) -> float:
return 1.0 if self.termination(obs) else -0.0001
@staticmethod
def create_maze() -> List[List[MazeCell]]:
E, B, R = MazeCell.EMPTY, MazeCell.BLOCK, MazeCell.ROBOT
return [
[B, B, B, B, B],
[B, R, E, E, B],
[B, B, B, E, B],
[B, E, E, E, B],
[B, B, B, B, B],
]
# The mix-in must come first so that its reward and REWARD_THRESHOLD win the MRO lookup.
class DistRewardUMaze(DistRewardMixIn, GoalRewardUMaze):
pass
class GoalRewardPush(GoalRewardUMaze):
def __init__(self, scale: float) -> None:
super().__init__(scale)
self.goals = [MazeGoal(np.array([0.0, 2.375 * scale]))]
@staticmethod
def create_maze() -> List[List[MazeCell]]:
E, B, R = MazeCell.EMPTY, MazeCell.BLOCK, MazeCell.ROBOT
return [
[B, B, B, B, B],
[B, E, R, B, B],
[B, E, MazeCell.XY, E, B],
[B, B, E, B, B],
[B, B, B, B, B],
]
class DistRewardPush(DistRewardMixIn, GoalRewardPush):
pass
class GoalRewardFall(GoalRewardUMaze):
def __init__(self, scale: float) -> None:
super().__init__(scale)
self.goals = [MazeGoal(np.array([0.0, 3.375 * scale, 4.5]))]
@staticmethod
def create_maze() -> List[List[MazeCell]]:
E, B, C, R = MazeCell.EMPTY, MazeCell.BLOCK, MazeCell.CHASM, MazeCell.ROBOT
return [
[B, B, B, B],
[B, R, E, B],
[B, E, MazeCell.YZ, B],
[B, C, C, B],
[B, E, E, B],
[B, B, B, B],
]
class DistRewardFall(DistRewardMixIn, GoalRewardFall):
pass
class GoalReward2Rooms(MazeTask):
REWARD_THRESHOLD: float = 0.9
MAZE_SIZE_SCALING: Scaling = Scaling(4.0, 4.0)
def __init__(self, scale: float) -> None:
super().__init__(scale)
self.goals = [MazeGoal(np.array([0.0, 4.0 * scale]))]
def reward(self, obs: np.ndarray) -> float:
for goal in self.goals:
if goal.neighbor(obs):
return goal.reward_scale
return -0.0001
@staticmethod
def create_maze() -> List[List[MazeCell]]:
E, B, R = MazeCell.EMPTY, MazeCell.BLOCK, MazeCell.ROBOT
return [
[B, B, B, B, B, B, B, B],
[B, R, E, E, E, E, E, B],
[B, E, E, E, E, E, E, B],
[B, B, B, B, B, E, B, B],
[B, E, E, E, E, E, E, B],
[B, E, E, E, E, E, E, B],
[B, B, B, B, B, B, B, B],
]
class DistReward2Rooms(DistRewardMixIn, GoalReward2Rooms):
pass
class SubGoal2Rooms(GoalReward2Rooms):
def __init__(self, scale: float) -> None:
super().__init__(scale)
self.goals.append(MazeGoal(np.array([5.0 * scale, 0.0 * scale]), 0.5, GREEN))
class GoalReward4Rooms(MazeTask):
REWARD_THRESHOLD: float = 0.9
MAZE_SIZE_SCALING: Scaling = Scaling(4.0, 4.0)
def __init__(self, scale: float) -> None:
super().__init__(scale)
self.goals = [MazeGoal(np.array([6.0 * scale, 6.0 * scale]))]
def reward(self, obs: np.ndarray) -> float:
for goal in self.goals:
if goal.neighbor(obs):
return goal.reward_scale
return -0.0001
@staticmethod
def create_maze() -> List[List[MazeCell]]:
E, B, R = MazeCell.EMPTY, MazeCell.BLOCK, MazeCell.ROBOT
return [
[B, B, B, B, B, B, B, B, B],
[B, R, E, E, B, E, E, E, B],
[B, E, E, E, E, E, E, E, B],
[B, E, E, E, B, E, E, E, B],
[B, B, E, B, B, B, E, B, B],
[B, E, E, E, B, E, E, E, B],
[B, E, E, E, E, E, E, E, B],
[B, E, E, E, B, E, E, E, B],
[B, B, B, B, B, B, B, B, B],
]
class DistReward4Rooms(DistRewardMixIn, GoalReward4Rooms):
pass
class SubGoal4Rooms(GoalReward4Rooms):
def __init__(self, scale: float) -> None:
super().__init__(scale)
self.goals += [
MazeGoal(np.array([0.0 * scale, 6.0 * scale]), 0.5, GREEN),
MazeGoal(np.array([6.0 * scale, 0.0 * scale]), 0.5, GREEN),
]
class TaskRegistry:
REGISTRY: Dict[str, List[Type[MazeTask]]] = {
"UMaze": [DistRewardUMaze, GoalRewardUMaze],
"Push": [DistRewardPush, GoalRewardPush],
"Fall": [DistRewardFall, GoalRewardFall],
"2Rooms": [DistReward2Rooms, GoalReward2Rooms, SubGoal2Rooms],
"4Rooms": [DistReward4Rooms, GoalReward4Rooms, SubGoal4Rooms],
}
@staticmethod
def keys() -> List[str]:
return list(TaskRegistry.REGISTRY.keys())
@staticmethod
def tasks(key: str) -> List[Type[MazeTask]]:
return TaskRegistry.REGISTRY[key]
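The registry is the extension point this PR advertises. A hedged sketch of a user-defined task (the Corridor task, its layout, and its goal position are invented for illustration; only MazeTask, MazeGoal, MazeCell, and TaskRegistry come from this module):

import numpy as np

from mujoco_maze.maze_env_utils import MazeCell
from mujoco_maze.maze_task import MazeGoal, MazeTask, TaskRegistry


class GoalRewardCorridor(MazeTask):
    REWARD_THRESHOLD: float = 0.9

    def __init__(self, scale: float) -> None:
        super().__init__(scale)
        # Goal three cells to the right of the start, in scaled coordinates.
        self.goals = [MazeGoal(np.array([3.0 * scale, 0.0]))]

    def reward(self, obs: np.ndarray) -> float:
        return 1.0 if self.termination(obs) else -0.0001

    @staticmethod
    def create_maze():
        E, B, R = MazeCell.EMPTY, MazeCell.BLOCK, MazeCell.ROBOT
        return [
            [B, B, B, B, B, B],
            [B, R, E, E, E, B],
            [B, B, B, B, B, B],
        ]


# After this, TaskRegistry.tasks("Corridor") returns the new task; registering
# a gym id for it would follow the same pattern as the loops in __init__.py.
TaskRegistry.REGISTRY["Corridor"] = [GoalRewardCorridor]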

View File

@ -1,69 +1,62 @@
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
A ball-like robot as an explorer in the maze.
Based on `models`_ and `rllab`_.
.. _models: https://github.com/tensorflow/models/tree/master/research/efficient-hrl
.. _rllab: https://github.com/rll/rllab
"""
"""Wrapper for creating the ant environment in gym_mujoco."""
import math
from typing import Optional, Tuple
import gym
import numpy as np
from mujoco_maze.agent_model import AgentModel
class PointEnv(AgentModel):
FILE = "point.xml"
ORI_IND = 2
FILE: str = "point.xml"
ORI_IND: int = 2
MANUAL_COLLISION: bool = True
RADIUS: float = 0.4
def __init__(self, file_path=None, expose_all_qpos=True):
self._expose_all_qpos = expose_all_qpos
VELOCITY_LIMITS: float = 10.0
def __init__(self, file_path: Optional[str] = None):
super().__init__(file_path, 1)
high = np.inf * np.ones(6, dtype=np.float32)
high[3:] = self.VELOCITY_LIMITS * 1.2
high[self.ORI_IND] = np.pi
low = -high
self.observation_space = gym.spaces.Box(low, high)
def _step(self, a):
return self.step(a)
def step(self, action):
qpos = np.copy(self.sim.data.qpos)
def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, dict]:
qpos = self.sim.data.qpos.copy()
qpos[2] += action[1]
# Clip orientation
if qpos[2] < -np.pi:
qpos[2] += np.pi * 2
elif np.pi < qpos[2]:
qpos[2] -= np.pi * 2
ori = qpos[2]
# compute increment in each direction
dx = math.cos(ori) * action[0]
dy = math.sin(ori) * action[0]
# ensure that the robot is within reasonable range
qpos[0] = np.clip(qpos[0] + dx, -100, 100)
qpos[1] = np.clip(qpos[1] + dy, -100, 100)
qvel = self.sim.data.qvel
# Compute increment in each direction
qpos[0] += math.cos(ori) * action[0]
qpos[1] += math.sin(ori) * action[0]
qvel = np.clip(self.sim.data.qvel, -self.VELOCITY_LIMITS, self.VELOCITY_LIMITS)
self.set_state(qpos, qvel)
for _ in range(0, self.frame_skip):
self.sim.step()
next_obs = self._get_obs()
reward = 0
done = False
info = {}
return next_obs, reward, done, info
return next_obs, 0.0, False, {}
def _get_obs(self):
if self._expose_all_qpos:
return np.concatenate(
[
self.sim.data.qpos.flat[:3], # Only point-relevant coords.
self.sim.data.qvel.flat[:3],
]
)
else:
return np.concatenate(
[self.sim.data.qpos.flat[2:3], self.sim.data.qvel.flat[:3]]
)
return np.concatenate(
[
self.sim.data.qpos.flat[:3], # Only point-relevant coords.
self.sim.data.qvel.flat[:3],
]
)
def reset_model(self):
qpos = self.init_qpos + self.np_random.uniform(
@ -78,15 +71,12 @@ class PointEnv(AgentModel):
return self._get_obs()
def get_xy(self):
return np.copy(self.sim.data.qpos[:2])
return self.sim.data.qpos[:2].copy()
def set_xy(self, xy):
qpos = np.copy(self.sim.data.qpos)
qpos[0] = xy[0]
qpos[1] = xy[1]
qvel = self.sim.data.qvel
self.set_state_without_forward(qpos, qvel)
def set_xy(self, xy: np.ndarray) -> None:
qpos = self.sim.data.qpos.copy()
qpos[:2] = xy
self.set_state(qpos, self.sim.data.qvel)
def get_ori(self):
return self.sim.data.qpos[self.ORI_IND]
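For intuition, one control step of this ball robot in plain Python (a sketch of the qpos update in step above: action[1] steers, action[0] drives along the current heading):

import math

ori, x, y = 0.0, 0.0, 0.0
action = (1.0, math.pi / 2)  # Turn 90 degrees, then drive forward one unit.
ori += action[1]  # qpos[2] += action[1], then wrapped into [-pi, pi].
x += math.cos(ori) * action[0]  # qpos[0]
y += math.sin(ori) * action[0]  # qpos[1]
assert abs(x) <= 1e-9 and abs(y - 1.0) <= 1e-9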

View File

@ -1,22 +0,0 @@
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from mujoco_maze.maze_env import MazeEnv
from mujoco_maze.point import PointEnv
class PointMazeEnv(MazeEnv):
MODEL_CLASS = PointEnv
MANUAL_COLLISION = True

BIN screenshots/Point4Rooms.png (new binary file, 79 KiB; not shown)

BIN screenshots/PointFall.png (new binary file, 30 KiB; not shown)

BIN screenshots/PointPush.png (new binary file, 41 KiB; not shown)

BIN screenshots/PointUMaze.png (new binary file, 45 KiB; not shown)

View File

@ -1,19 +1,22 @@
import gym
import pytest
import mujoco_maze
@pytest.mark.parametrize("maze_id", mujoco_maze.MAZE_IDS)
@pytest.mark.parametrize("maze_id", mujoco_maze.TaskRegistry.keys())
def test_ant_maze(maze_id):
env = gym.make("Ant{}-v0".format(maze_id))
assert env.reset().shape == (30,)
s, _, _, _ = env.step(env.action_space.sample())
assert s.shape == (30,)
for i in range(2):
env = gym.make(f"Ant{maze_id}-v{i}")
assert env.reset().shape == (30,)
s, _, _, _ = env.step(env.action_space.sample())
assert s.shape == (30,)
@pytest.mark.parametrize("maze_id", mujoco_maze.MAZE_IDS)
@pytest.mark.parametrize("maze_id", mujoco_maze.TaskRegistry.keys())
def test_point_maze(maze_id):
env = gym.make("Point{}-v0".format(maze_id))
assert env.reset().shape == (7,)
s, _, _, _ = env.step(env.action_space.sample())
assert s.shape == (7,)
for i in range(2):
env = gym.make(f"Point{maze_id}-v{i}")
assert env.reset().shape == (7,)
s, _, _, _ = env.step(env.action_space.sample())
assert s.shape == (7,)

tests/test_intersect.py (new file, 84 lines)
View File

@ -0,0 +1,84 @@
import numpy as np
import pytest
from mujoco_maze.maze_env_utils import Line
@pytest.mark.parametrize(
"l1, l2, p, ans",
[
((0.0, 0.0), (4.0, 4.0), (1.0, 3.0), 2.0 ** 0.5),
((-3.0, -3.0), (0.0, 1.0), (-3.0, 1.0), 2.4),
],
)
def test_distance(l1, l2, p, ans):
line = Line(l1, l2)
point = np.complex(*p)
assert abs(line.distance(point) - ans) <= 1e-8
@pytest.mark.parametrize(
"l1p1, l1p2, l2p1, l2p2, none",
[
((0.0, 0.0), (1.0, 0.0), (0.0, -1.0), (1.0, 1.0), False),
((1.0, 1.0), (2.0, 3.0), (-1.0, 1.5), (1.5, 1.0), False),
((1.5, 1.5), (2.0, 3.0), (-1.0, 1.5), (1.5, 1.0), True),
((0.0, 0.0), (2.0, 0.0), (1.0, 0.0), (1.0, 3.0), False),
],
)
def test_intersect(l1p1, l1p2, l2p1, l2p2, none):
l1 = Line(l1p1, l1p2)
l2 = Line(l2p1, l2p2)
i1 = l1.intersect(l2)
i2 = line_intersect(l1p1, l1p2, l2p1, l2p2)
if none:
assert i1 is None and i2 is None
else:
assert i1 is not None
i1 = np.array([i1.real, i1.imag])
np.testing.assert_array_almost_equal(i1, np.array(i2))
def line_intersect(pt1, pt2, ptA, ptB):
"""
Taken from https://www.cs.hmc.edu/ACM/lectures/intersections.html
Returns the intersection of Line(pt1,pt2) and Line(ptA,ptB).
"""
import math
DET_TOLERANCE = 0.00000001
# the first line is pt1 + r*(pt2-pt1)
# in component form:
x1, y1 = pt1
x2, y2 = pt2
dx1 = x2 - x1
dy1 = y2 - y1
# the second line is ptA + s*(ptB-ptA)
x, y = ptA
xB, yB = ptB
dx = xB - x
dy = yB - y
DET = -dx1 * dy + dy1 * dx
if math.fabs(DET) < DET_TOLERANCE:
return None
# now, the determinant should be OK
DETinv = 1.0 / DET
# find the scalar amount along the "self" segment
r = DETinv * (-dy * (x - x1) + dx * (y - y1))
# find the scalar amount along the input line
s = DETinv * (-dy1 * (x - x1) + dx1 * (y - y1))
# return the average of the two descriptions
xi = (x1 + r * dx1 + x + s * dx) / 2.0
yi = (y1 + r * dy1 + y + s * dy) / 2.0
if r >= 0 and 0 <= s <= 1:
return xi, yi
else:
return None