Fix collision detection
This commit is contained in:
parent
bd7dd5bcfc
commit
33032dd48e
@ -40,7 +40,7 @@ for maze_id in MAZE_IDS:
|
|||||||
entry_point="mujoco_maze.point_maze_env:PointMazeEnv",
|
entry_point="mujoco_maze.point_maze_env:PointMazeEnv",
|
||||||
kwargs=dict(**_get_kwargs(maze_id), dense_reward=False),
|
kwargs=dict(**_get_kwargs(maze_id), dense_reward=False),
|
||||||
max_episode_steps=1000,
|
max_episode_steps=1000,
|
||||||
reward_threshold=0.9
|
reward_threshold=0.9,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from gym.envs.mujoco.mujoco_env import MujocoEnv
|
from gym.envs.mujoco.mujoco_env import MujocoEnv
|
||||||
from gym.utils import EzPickle
|
from gym.utils import EzPickle
|
||||||
|
from mujoco_py import MjSimState
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
@ -14,6 +15,15 @@ class AgentModel(ABC, MujocoEnv, EzPickle):
|
|||||||
MujocoEnv.__init__(self, file_path, frame_skip)
|
MujocoEnv.__init__(self, file_path, frame_skip)
|
||||||
EzPickle.__init__(self)
|
EzPickle.__init__(self)
|
||||||
|
|
||||||
|
def set_state_without_forward(self, qpos, qvel):
|
||||||
|
assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,)
|
||||||
|
old_state = self.sim.get_state()
|
||||||
|
new_state = MjSimState(
|
||||||
|
old_state.time, qpos, qvel, old_state.act, old_state.udd_state
|
||||||
|
)
|
||||||
|
self.sim.set_state(new_state)
|
||||||
|
self.sim.forward()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _get_obs(self) -> np.ndarray:
|
def _get_obs(self) -> np.ndarray:
|
||||||
"""Returns the observation from the model.
|
"""Returns the observation from the model.
|
||||||
|
@ -137,7 +137,7 @@ class AntEnv(AgentModel):
|
|||||||
qpos[1] = xy[1]
|
qpos[1] = xy[1]
|
||||||
|
|
||||||
qvel = self.sim.data.qvel
|
qvel = self.sim.data.qvel
|
||||||
self.set_state(qpos, qvel)
|
self.set_state_without_forwarding(qpos, qvel)
|
||||||
|
|
||||||
def get_xy(self):
|
def get_xy(self):
|
||||||
return self.sim.data.qpos[:2]
|
return np.copy(self.sim.data.qpos[:2])
|
||||||
|
@ -16,8 +16,8 @@
|
|||||||
<light directional="true" cutoff="100" exponent="1" diffuse="1 1 1" specular=".1 .1 .1" pos="0 0 1.3" dir="-0 0 -1.3" />
|
<light directional="true" cutoff="100" exponent="1" diffuse="1 1 1" specular=".1 .1 .1" pos="0 0 1.3" dir="-0 0 -1.3" />
|
||||||
<geom name="floor" material="MatPlane" pos="0 0 0" size="40 40 40" type="plane" conaffinity="1" rgba="0.8 0.9 0.8 1" condim="3" />
|
<geom name="floor" material="MatPlane" pos="0 0 0" size="40 40 40" type="plane" conaffinity="1" rgba="0.8 0.9 0.8 1" condim="3" />
|
||||||
<body name="torso" pos="0 0 0">
|
<body name="torso" pos="0 0 0">
|
||||||
<geom name="pointbody" type="sphere" size="0.5" pos="0 0 0.5" solimp="0.9995 0.9999 0.001" />
|
<geom name="pointbody" type="sphere" size="0.5" pos="0 0 0.5" solimp="0.9 0.99 0.001" />
|
||||||
<geom name="pointarrow" type="box" size="0.5 0.1 0.1" pos="0.6 0 0.5" solimp="0.9995 0.9999 0.001" />
|
<geom name="pointarrow" type="box" size="0.5 0.1 0.1" pos="0.6 0 0.5" solimp="0.9 0.99 0.001" />
|
||||||
<joint name="ballx" type="slide" axis="1 0 0" pos="0 0 0" />
|
<joint name="ballx" type="slide" axis="1 0 0" pos="0 0 0" />
|
||||||
<joint name="bally" type="slide" axis="0 1 0" pos="0 0 0" />
|
<joint name="bally" type="slide" axis="0 1 0" pos="0 0 0" />
|
||||||
<joint name="rot" type="hinge" axis="0 0 1" pos="0 0 0" limited="false" />
|
<joint name="rot" type="hinge" axis="0 0 1" pos="0 0 0" limited="false" />
|
||||||
|
@ -70,6 +70,7 @@ class MazeEnv(gym.Env):
|
|||||||
self._observe_blocks = observe_blocks
|
self._observe_blocks = observe_blocks
|
||||||
self._put_spin_near_agent = put_spin_near_agent
|
self._put_spin_near_agent = put_spin_near_agent
|
||||||
self._top_down_view = top_down_view
|
self._top_down_view = top_down_view
|
||||||
|
self._collision_coef = 0.1
|
||||||
|
|
||||||
self._maze_structure = structure = maze_env_utils.construct_maze(
|
self._maze_structure = structure = maze_env_utils.construct_maze(
|
||||||
maze_id=self._maze_id
|
maze_id=self._maze_id
|
||||||
@ -164,7 +165,11 @@ class MazeEnv(gym.Env):
|
|||||||
spinning = maze_env_utils.can_spin(struct)
|
spinning = maze_env_utils.can_spin(struct)
|
||||||
shrink = 0.1 if spinning else 0.99 if falling else 1.0
|
shrink = 0.1 if spinning else 0.99 if falling else 1.0
|
||||||
height_shrink = 0.1 if spinning else 1.0
|
height_shrink = 0.1 if spinning else 1.0
|
||||||
x = j * size_scaling - torso_x + 0.25 * size_scaling if spinning else 0.0
|
x = (
|
||||||
|
j * size_scaling - torso_x + 0.25 * size_scaling
|
||||||
|
if spinning
|
||||||
|
else 0.0
|
||||||
|
)
|
||||||
y = i * size_scaling - torso_y
|
y = i * size_scaling - torso_y
|
||||||
h = height / 2 * size_scaling * height_shrink
|
h = height / 2 * size_scaling * height_shrink
|
||||||
size = 0.5 * size_scaling * shrink + self.SIZE_EPS
|
size = 0.5 * size_scaling * shrink + self.SIZE_EPS
|
||||||
@ -530,7 +535,7 @@ class MazeEnv(gym.Env):
|
|||||||
old_pos = self.wrapped_env.get_xy()
|
old_pos = self.wrapped_env.get_xy()
|
||||||
inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action)
|
inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action)
|
||||||
new_pos = self.wrapped_env.get_xy()
|
new_pos = self.wrapped_env.get_xy()
|
||||||
if self._collision.is_in(old_pos, new_pos):
|
if self._collision.is_in(new_pos):
|
||||||
self.wrapped_env.set_xy(old_pos)
|
self.wrapped_env.set_xy(old_pos)
|
||||||
else:
|
else:
|
||||||
inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action)
|
inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action)
|
||||||
|
@ -104,7 +104,7 @@ class Collision:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
ARROUND = np.array([[-1, 0], [1, 0], [0, -1], [0, 1]])
|
ARROUND = np.array([[-1, 0], [1, 0], [0, -1], [0, 1]])
|
||||||
OFFSET = {False: 0.5, True: 0.55}
|
OFFSET = {False: 0.48, True: 0.51}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, structure: list, size_scaling: float, torso_x: float, torso_y: float,
|
self, structure: list, size_scaling: float, torso_x: float, torso_y: float,
|
||||||
@ -134,8 +134,8 @@ class Collision:
|
|||||||
max_x = x_base + size_scaling * offset(pos, 3)
|
max_x = x_base + size_scaling * offset(pos, 3)
|
||||||
self.objects.append((min_y, max_y, min_x, max_x))
|
self.objects.append((min_y, max_y, min_x, max_x))
|
||||||
|
|
||||||
def is_in(self, old_pos, new_pos) -> bool:
|
def is_in(self, new_pos) -> bool:
|
||||||
for x, y in (new_pos, (old_pos + new_pos) / 2):
|
x, y = new_pos
|
||||||
for min_y, max_y, min_x, max_x in self.objects:
|
for min_y, max_y, min_x, max_x in self.objects:
|
||||||
if min_x <= x <= max_x and min_y <= y <= max_y:
|
if min_x <= x <= max_x and min_y <= y <= max_y:
|
||||||
return True
|
return True
|
||||||
|
@ -78,7 +78,7 @@ class PointEnv(AgentModel):
|
|||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
|
||||||
def get_xy(self):
|
def get_xy(self):
|
||||||
return self.sim.data.qpos[:2]
|
return np.copy(self.sim.data.qpos[:2])
|
||||||
|
|
||||||
def set_xy(self, xy):
|
def set_xy(self, xy):
|
||||||
qpos = np.copy(self.sim.data.qpos)
|
qpos = np.copy(self.sim.data.qpos)
|
||||||
@ -86,7 +86,7 @@ class PointEnv(AgentModel):
|
|||||||
qpos[1] = xy[1]
|
qpos[1] = xy[1]
|
||||||
|
|
||||||
qvel = self.sim.data.qvel
|
qvel = self.sim.data.qvel
|
||||||
self.set_state(qpos, qvel)
|
self.set_state_without_forward(qpos, qvel)
|
||||||
|
|
||||||
def get_ori(self):
|
def get_ori(self):
|
||||||
return self.sim.data.qpos[self.ORI_IND]
|
return self.sim.data.qpos[self.ORI_IND]
|
||||||
|
Loading…
Reference in New Issue
Block a user