diff --git a/mujoco_maze/__init__.py b/mujoco_maze/__init__.py index 5f00d6e..fcab7a0 100644 --- a/mujoco_maze/__init__.py +++ b/mujoco_maze/__init__.py @@ -40,7 +40,7 @@ for maze_id in MAZE_IDS: entry_point="mujoco_maze.point_maze_env:PointMazeEnv", kwargs=dict(**_get_kwargs(maze_id), dense_reward=False), max_episode_steps=1000, - reward_threshold=0.9 + reward_threshold=0.9, ) diff --git a/mujoco_maze/agent_model.py b/mujoco_maze/agent_model.py index 63fbd3d..5436c96 100644 --- a/mujoco_maze/agent_model.py +++ b/mujoco_maze/agent_model.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from gym.envs.mujoco.mujoco_env import MujocoEnv from gym.utils import EzPickle +from mujoco_py import MjSimState import numpy as np @@ -14,6 +15,15 @@ class AgentModel(ABC, MujocoEnv, EzPickle): MujocoEnv.__init__(self, file_path, frame_skip) EzPickle.__init__(self) + def set_state_without_forward(self, qpos, qvel): + assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,) + old_state = self.sim.get_state() + new_state = MjSimState( + old_state.time, qpos, qvel, old_state.act, old_state.udd_state + ) + self.sim.set_state(new_state) + self.sim.forward() + @abstractmethod def _get_obs(self) -> np.ndarray: """Returns the observation from the model. 
diff --git a/mujoco_maze/ant.py b/mujoco_maze/ant.py index d1bb338..1dcabc5 100644 --- a/mujoco_maze/ant.py +++ b/mujoco_maze/ant.py @@ -137,7 +137,7 @@ class AntEnv(AgentModel): qpos[1] = xy[1] qvel = self.sim.data.qvel - self.set_state(qpos, qvel) + self.set_state_without_forward(qpos, qvel) def get_xy(self): - return self.sim.data.qpos[:2] + return np.copy(self.sim.data.qpos[:2]) diff --git a/mujoco_maze/assets/point.xml b/mujoco_maze/assets/point.xml index 8cd89de..2d90f75 100755 --- a/mujoco_maze/assets/point.xml +++ b/mujoco_maze/assets/point.xml @@ -16,8 +16,8 @@ - - + + diff --git a/mujoco_maze/maze_env.py b/mujoco_maze/maze_env.py index 7b60293..3a8b15b 100644 --- a/mujoco_maze/maze_env.py +++ b/mujoco_maze/maze_env.py @@ -70,6 +70,7 @@ class MazeEnv(gym.Env): self._observe_blocks = observe_blocks self._put_spin_near_agent = put_spin_near_agent self._top_down_view = top_down_view + self._collision_coef = 0.1 self._maze_structure = structure = maze_env_utils.construct_maze( maze_id=self._maze_id @@ -164,7 +165,11 @@ class MazeEnv(gym.Env): spinning = maze_env_utils.can_spin(struct) shrink = 0.1 if spinning else 0.99 if falling else 1.0 height_shrink = 0.1 if spinning else 1.0 - x = j * size_scaling - torso_x + 0.25 * size_scaling if spinning else 0.0 + x = ( + j * size_scaling - torso_x + 0.25 * size_scaling + if spinning + else 0.0 + ) y = i * size_scaling - torso_y h = height / 2 * size_scaling * height_shrink size = 0.5 * size_scaling * shrink + self.SIZE_EPS @@ -530,7 +535,7 @@ class MazeEnv(gym.Env): old_pos = self.wrapped_env.get_xy() inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action) new_pos = self.wrapped_env.get_xy() - if self._collision.is_in(old_pos, new_pos): + if self._collision.is_in(new_pos): self.wrapped_env.set_xy(old_pos) else: inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action) diff --git a/mujoco_maze/maze_env_utils.py b/mujoco_maze/maze_env_utils.py index d59b922..b5f50e8 100644 --- 
a/mujoco_maze/maze_env_utils.py +++ b/mujoco_maze/maze_env_utils.py @@ -104,7 +104,7 @@ class Collision: """ ARROUND = np.array([[-1, 0], [1, 0], [0, -1], [0, 1]]) - OFFSET = {False: 0.5, True: 0.55} + OFFSET = {False: 0.48, True: 0.51} def __init__( self, structure: list, size_scaling: float, torso_x: float, torso_y: float, @@ -134,11 +134,11 @@ class Collision: max_x = x_base + size_scaling * offset(pos, 3) self.objects.append((min_y, max_y, min_x, max_x)) - def is_in(self, old_pos, new_pos) -> bool: - for x, y in (new_pos, (old_pos + new_pos) / 2): - for min_y, max_y, min_x, max_x in self.objects: - if min_x <= x <= max_x and min_y <= y <= max_y: - return True + def is_in(self, new_pos) -> bool: + x, y = new_pos + for min_y, max_y, min_x, max_x in self.objects: + if min_x <= x <= max_x and min_y <= y <= max_y: + return True return False diff --git a/mujoco_maze/point.py b/mujoco_maze/point.py index 9e530d4..96cf53e 100644 --- a/mujoco_maze/point.py +++ b/mujoco_maze/point.py @@ -78,7 +78,7 @@ class PointEnv(AgentModel): return self._get_obs() def get_xy(self): - return self.sim.data.qpos[:2] + return np.copy(self.sim.data.qpos[:2]) def set_xy(self, xy): qpos = np.copy(self.sim.data.qpos) @@ -86,7 +86,7 @@ class PointEnv(AgentModel): qpos[1] = xy[1] qvel = self.sim.data.qvel - self.set_state(qpos, qvel) + self.set_state_without_forward(qpos, qvel) def get_ori(self): return self.sim.data.qpos[self.ORI_IND]