diff --git a/mujoco_maze/__init__.py b/mujoco_maze/__init__.py
index 5f00d6e..fcab7a0 100644
--- a/mujoco_maze/__init__.py
+++ b/mujoco_maze/__init__.py
@@ -40,7 +40,7 @@ for maze_id in MAZE_IDS:
entry_point="mujoco_maze.point_maze_env:PointMazeEnv",
kwargs=dict(**_get_kwargs(maze_id), dense_reward=False),
max_episode_steps=1000,
- reward_threshold=0.9
+ reward_threshold=0.9,
)
diff --git a/mujoco_maze/agent_model.py b/mujoco_maze/agent_model.py
index 63fbd3d..5436c96 100644
--- a/mujoco_maze/agent_model.py
+++ b/mujoco_maze/agent_model.py
@@ -3,6 +3,7 @@
from abc import ABC, abstractmethod
from gym.envs.mujoco.mujoco_env import MujocoEnv
from gym.utils import EzPickle
+from mujoco_py import MjSimState
import numpy as np
@@ -14,6 +15,15 @@ class AgentModel(ABC, MujocoEnv, EzPickle):
MujocoEnv.__init__(self, file_path, frame_skip)
EzPickle.__init__(self)
+ def set_state_without_forward(self, qpos, qvel):
+ assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,)
+ old_state = self.sim.get_state()
+ new_state = MjSimState(
+ old_state.time, qpos, qvel, old_state.act, old_state.udd_state
+ )
+ self.sim.set_state(new_state)
+ self.sim.forward()
+
@abstractmethod
def _get_obs(self) -> np.ndarray:
"""Returns the observation from the model.
diff --git a/mujoco_maze/ant.py b/mujoco_maze/ant.py
index d1bb338..1dcabc5 100644
--- a/mujoco_maze/ant.py
+++ b/mujoco_maze/ant.py
@@ -137,7 +137,7 @@ class AntEnv(AgentModel):
qpos[1] = xy[1]
qvel = self.sim.data.qvel
- self.set_state(qpos, qvel)
+ self.set_state_without_forward(qpos, qvel)
def get_xy(self):
- return self.sim.data.qpos[:2]
+ return np.copy(self.sim.data.qpos[:2])
diff --git a/mujoco_maze/assets/point.xml b/mujoco_maze/assets/point.xml
index 8cd89de..2d90f75 100755
--- a/mujoco_maze/assets/point.xml
+++ b/mujoco_maze/assets/point.xml
@@ -16,8 +16,8 @@
-
-
+
+
diff --git a/mujoco_maze/maze_env.py b/mujoco_maze/maze_env.py
index 7b60293..3a8b15b 100644
--- a/mujoco_maze/maze_env.py
+++ b/mujoco_maze/maze_env.py
@@ -70,6 +70,7 @@ class MazeEnv(gym.Env):
self._observe_blocks = observe_blocks
self._put_spin_near_agent = put_spin_near_agent
self._top_down_view = top_down_view
+ self._collision_coef = 0.1
self._maze_structure = structure = maze_env_utils.construct_maze(
maze_id=self._maze_id
@@ -164,7 +165,11 @@ class MazeEnv(gym.Env):
spinning = maze_env_utils.can_spin(struct)
shrink = 0.1 if spinning else 0.99 if falling else 1.0
height_shrink = 0.1 if spinning else 1.0
- x = j * size_scaling - torso_x + 0.25 * size_scaling if spinning else 0.0
+ x = (
+ j * size_scaling - torso_x + 0.25 * size_scaling
+ if spinning
+ else 0.0
+ )
y = i * size_scaling - torso_y
h = height / 2 * size_scaling * height_shrink
size = 0.5 * size_scaling * shrink + self.SIZE_EPS
@@ -530,7 +535,7 @@ class MazeEnv(gym.Env):
old_pos = self.wrapped_env.get_xy()
inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action)
new_pos = self.wrapped_env.get_xy()
- if self._collision.is_in(old_pos, new_pos):
+ if self._collision.is_in(new_pos):
self.wrapped_env.set_xy(old_pos)
else:
inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action)
diff --git a/mujoco_maze/maze_env_utils.py b/mujoco_maze/maze_env_utils.py
index d59b922..b5f50e8 100644
--- a/mujoco_maze/maze_env_utils.py
+++ b/mujoco_maze/maze_env_utils.py
@@ -104,7 +104,7 @@ class Collision:
"""
ARROUND = np.array([[-1, 0], [1, 0], [0, -1], [0, 1]])
- OFFSET = {False: 0.5, True: 0.55}
+ OFFSET = {False: 0.48, True: 0.51}
def __init__(
self, structure: list, size_scaling: float, torso_x: float, torso_y: float,
@@ -134,11 +134,11 @@ class Collision:
max_x = x_base + size_scaling * offset(pos, 3)
self.objects.append((min_y, max_y, min_x, max_x))
- def is_in(self, old_pos, new_pos) -> bool:
- for x, y in (new_pos, (old_pos + new_pos) / 2):
- for min_y, max_y, min_x, max_x in self.objects:
- if min_x <= x <= max_x and min_y <= y <= max_y:
- return True
+ def is_in(self, new_pos) -> bool:
+ x, y = new_pos
+ for min_y, max_y, min_x, max_x in self.objects:
+ if min_x <= x <= max_x and min_y <= y <= max_y:
+ return True
return False
diff --git a/mujoco_maze/point.py b/mujoco_maze/point.py
index 9e530d4..96cf53e 100644
--- a/mujoco_maze/point.py
+++ b/mujoco_maze/point.py
@@ -78,7 +78,7 @@ class PointEnv(AgentModel):
return self._get_obs()
def get_xy(self):
- return self.sim.data.qpos[:2]
+ return np.copy(self.sim.data.qpos[:2])
def set_xy(self, xy):
qpos = np.copy(self.sim.data.qpos)
@@ -86,7 +86,7 @@ class PointEnv(AgentModel):
qpos[1] = xy[1]
qvel = self.sim.data.qvel
- self.set_state(qpos, qvel)
+ self.set_state_without_forward(qpos, qvel)
def get_ori(self):
return self.sim.data.qpos[self.ORI_IND]