diff --git a/mujoco_maze/__init__.py b/mujoco_maze/__init__.py
index 307005a..55bf2e4 100644
--- a/mujoco_maze/__init__.py
+++ b/mujoco_maze/__init__.py
@@ -3,11 +3,18 @@ import gym
 
 MAZE_IDS = ["Maze", "Push", "Fall", "Block", "BlockMaze"]
 
+def _get_kwargs(maze_id: str) -> dict:
+    return {
+        "observe_blocks": maze_id in ["Block", "BlockMaze"],
+        "put_spin_near_agent": maze_id in ["Block", "BlockMaze"],
+    }
+
+
 for maze_id in MAZE_IDS:
     gym.envs.register(
         id="AntMaze{}-v0".format(maze_id),
         entry_point="mujoco_maze.ant_maze_env:AntMazeEnv",
-        kwargs=dict(maze_id=maze_id, manual_collision=True),
+        kwargs=dict(maze_id=maze_id, maze_size_scaling=8, **_get_kwargs(maze_id)),
         max_episode_steps=1000,
         reward_threshold=-1000,
     )
@@ -16,7 +23,12 @@ for maze_id in MAZE_IDS:
     gym.envs.register(
         id="PointMaze{}-v0".format(maze_id),
         entry_point="mujoco_maze.point_maze_env:PointMazeEnv",
-        kwargs=dict(maze_id=maze_id, manual_collision=True),
+        kwargs=dict(
+            maze_id=maze_id,
+            maze_size_scaling=4,
+            manual_collision=True,
+            **_get_kwargs(maze_id),
+        ),
         max_episode_steps=1000,
         reward_threshold=-1000,
     )
diff --git a/mujoco_maze/agent_model.py b/mujoco_maze/agent_model.py
index d2b7049..627c1b5 100644
--- a/mujoco_maze/agent_model.py
+++ b/mujoco_maze/agent_model.py
@@ -36,4 +36,3 @@ class AgentModel(ABC, MujocoEnv, EzPickle):
     @abstractmethod
     def get_ori(self) -> float:
         pass
-
diff --git a/mujoco_maze/ant.py b/mujoco_maze/ant.py
index cad281a..d1bb338 100644
--- a/mujoco_maze/ant.py
+++ b/mujoco_maze/ant.py
@@ -126,7 +126,7 @@ class AntEnv(AgentModel):
     def get_ori(self):
         ori = [0, 1, 0, 0]
         ori_ind = self.ORI_IND
-        rot = self.sim.data.qpos[ori_ind: ori_ind + 4]  # take the quaternion
+        rot = self.sim.data.qpos[ori_ind : ori_ind + 4]  # take the quaternion
         ori = q_mult(q_mult(rot, ori), q_inv(rot))[1:3]  # project onto x-y plane
         ori = math.atan2(ori[1], ori[0])
         return ori
diff --git a/mujoco_maze/maze_env.py b/mujoco_maze/maze_env.py
index 912f32f..cb38edb 100644
--- a/mujoco_maze/maze_env.py
+++ b/mujoco_maze/maze_env.py
@@ -22,7 +22,7 @@ import math
 import numpy as np
 import gym
 
-from typing import Type
+from typing import Callable, Type, Union
 
 from mujoco_maze.agent_model import AgentModel
 from mujoco_maze import maze_env_utils
@@ -49,6 +49,8 @@ class MazeEnv(gym.Env):
         put_spin_near_agent=False,
         top_down_view=False,
         manual_collision=False,
+        dense_reward=True,
+        goal_sampler: Union[str, np.ndarray, Callable[[], np.ndarray]] = "default",
         *args,
         **kwargs,
     ):
@@ -162,7 +164,7 @@ class MazeEnv(gym.Env):
                 )
             elif maze_env_utils.can_move(struct):  # Movable block.
                 # The "falling" blocks are shrunk slightly and increased in mass to
-                # ensure that it can fall easily through a gap in the platform blocks.
+                # ensure they can fall easily through a gap in the platform blocks.
name = "movable_%d_%d" % (i, j) self.movable_blocks.append((name, struct)) falling = maze_env_utils.can_move_z(struct) @@ -265,6 +267,29 @@ class MazeEnv(gym.Env): tree.write(file_path) self.wrapped_env = self.MODEL_CLASS(*args, file_path=file_path, **kwargs) + # Set reward function + self._reward_fn = _reward_fn(maze_id, dense_reward) + + # Set goal sampler + if isinstance(goal_sampler, str): + if goal_sampler == "random": + self._goal_sampler = lambda: np.random.uniform((-4, -4), (20, 20)) + elif goal_sampler == "default": + default_goal = _default_goal(maze_id) + self._goal_sampler = lambda: default_goal + else: + raise NotImplementedError(f"Unknown goal_sampler: {goal_sampler}") + elif isinstance(goal_sampler, np.ndarray): + self._goal_sampler = lambda: goal_sampler + elif callable(goal_sampler): + self._goal_sampler = goal_sampler + else: + raise ValueError(f"Invalid goal_sampler: {goal_sampler}") + self.goal = self._goal_sampler() + + # Set goal function + self._goal_fn = _goal_fn(maze_id) + def get_ori(self): return self.wrapped_env.get_ori() @@ -472,6 +497,8 @@ class MazeEnv(gym.Env): def reset(self): self.t = 0 self.wrapped_env.reset() + # Sample a new goal + self.goal = self._goal_sampler() if len(self._init_positions) > 1: xy = np.random.choice(self._init_positions) self.wrapped_env.set_xy(xy) @@ -529,15 +556,57 @@ class MazeEnv(gym.Env): return True return False + def _is_in_goal(self, pos): + (np.linalg.norm(obs[:3] - goal) <= 0.6) + def step(self, action): self.t += 1 if self._manual_collision: old_pos = self.wrapped_env.get_xy() - inner_next_obs, inner_reward, done, info = self.wrapped_env.step(action) + inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action) new_pos = self.wrapped_env.get_xy() if self._is_in_collision(new_pos): self.wrapped_env.set_xy(old_pos) else: - inner_next_obs, inner_reward, done, info = self.wrapped_env.step(action) + inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action) next_obs = self._get_obs() - return next_obs, inner_reward, False, info + outer_reward = self._reward_fn(next_obs, self.goal) + done = self._goal_fn(next_obs, self.goal) + return next_obs, inner_reward + outer_reward, done, info + + +def _goal_fn(maze_id: str) -> callable: + if maze_id in ["Maze", "Push"]: + return lambda obs, goal: np.linalg.norm(obs[:2] - goal) <= 0.6 + elif maze_id == "Fall": + return lambda obs, goal: np.linalg.norm(obs[:3] - goal) <= 0.6 + else: + raise NotImplementedError(f"Unknown maze id: {maze_id}") + + +def _reward_fn(maze_id: str, dense: str) -> callable: + if dense: + if maze_id in ["Maze", "Push"]: + return lambda obs, goal: -np.sum(np.square(obs[:2] - goal)) ** 0.5 + elif maze_id == "Fall": + return lambda obs, goal: -np.sum(np.square(obs[:3] - goal)) ** 0.5 + else: + raise NotImplementedError(f"Unknown maze id: {maze_id}") + else: + if maze_id in ["Maze", "Push"]: + return lambda obs, goal: (np.linalg.norm(obs[:2] - goal) <= 0.6) * 1.0 + elif maze_id == "Fall": + return lambda obs, goal: (np.linalg.norm(obs[:3] - goal) <= 0.6) * 1.0 + else: + raise NotImplementedError(f"Unknown maze id: {maze_id}") + + +def _default_goal(maze_id: str) -> np.ndarray: + if maze_id == "Maze": + return np.array([0.0, 8.0]) + elif maze_id == "Push": + return np.array([0.0, 19.0]) + elif maze_id == "Fall": + return np.array([0.0, 27.0, 4.5]) + else: + raise NotImplementedError(f"Unknown maze id: {maze_id}")