Fix intersection
This commit is contained in:
parent
c3ab9c9545
commit
a67db885a2
@ -31,8 +31,7 @@ class MazeEnv(gym.Env):
|
|||||||
maze_height: float = 0.5,
|
maze_height: float = 0.5,
|
||||||
maze_size_scaling: float = 4.0,
|
maze_size_scaling: float = 4.0,
|
||||||
inner_reward_scaling: float = 1.0,
|
inner_reward_scaling: float = 1.0,
|
||||||
restitution_coef: float = 0.6,
|
restitution_coef: float = 0.9,
|
||||||
collision_penalty: float = 0.001,
|
|
||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -50,7 +49,6 @@ class MazeEnv(gym.Env):
|
|||||||
self._put_spin_near_agent = self._task.PUT_SPIN_NEAR_AGENT
|
self._put_spin_near_agent = self._task.PUT_SPIN_NEAR_AGENT
|
||||||
self._top_down_view = top_down_view
|
self._top_down_view = top_down_view
|
||||||
self._restitution_coef = restitution_coef
|
self._restitution_coef = restitution_coef
|
||||||
self._collision_penalty = collision_penalty
|
|
||||||
|
|
||||||
self._maze_structure = structure = self._task.create_maze()
|
self._maze_structure = structure = self._task.create_maze()
|
||||||
# Elevate the maze to allow for falling.
|
# Elevate the maze to allow for falling.
|
||||||
@ -437,12 +435,16 @@ class MazeEnv(gym.Env):
|
|||||||
old_pos = self.wrapped_env.get_xy()
|
old_pos = self.wrapped_env.get_xy()
|
||||||
inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action)
|
inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action)
|
||||||
new_pos = self.wrapped_env.get_xy()
|
new_pos = self.wrapped_env.get_xy()
|
||||||
|
# Checks that new_position is in the wall
|
||||||
intersection = self._collision.detect_intersection(old_pos, new_pos)
|
intersection = self._collision.detect_intersection(old_pos, new_pos)
|
||||||
if intersection is not None:
|
if intersection is not None:
|
||||||
rest_vec = -self._restitution_coef * (new_pos - intersection)
|
pos = intersection + (intersection - new_pos) * self._restitution_coef
|
||||||
pos = old_pos + rest_vec
|
# Checks that pos is in the wall
|
||||||
|
intersection2 = self._collision.detect_intersection(old_pos, pos)
|
||||||
|
if intersection2 is not None:
|
||||||
|
# If pos is not in the wall, we give up computing the position
|
||||||
|
pos = old_pos
|
||||||
self.wrapped_env.set_collision(pos, self._restitution_coef)
|
self.wrapped_env.set_collision(pos, self._restitution_coef)
|
||||||
inner_reward -= self._collision_penalty
|
|
||||||
else:
|
else:
|
||||||
inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action)
|
inner_next_obs, inner_reward, _, info = self.wrapped_env.step(action)
|
||||||
next_obs = self._get_obs()
|
next_obs = self._get_obs()
|
||||||
|
@ -81,7 +81,7 @@ class PointEnv(AgentModel):
|
|||||||
qpos = self.sim.data.qpos.copy()
|
qpos = self.sim.data.qpos.copy()
|
||||||
qpos[:2] = xy
|
qpos[:2] = xy
|
||||||
qvel = self.sim.data.qvel.copy()
|
qvel = self.sim.data.qvel.copy()
|
||||||
qvel[:2] = -restitution_coef * qvel[:2]
|
qvel[:2] *= -restitution_coef
|
||||||
self.set_state(qpos, qvel)
|
self.set_state(qpos, qvel)
|
||||||
|
|
||||||
def get_ori(self):
|
def get_ori(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user