fixed hole reacher bug

This commit is contained in:
ottofabian 2021-05-18 15:27:08 +02:00
parent e0e4d6d41c
commit 724b8c6c61
3 changed files with 18 additions and 20 deletions

View File

@ -172,6 +172,7 @@ register(
max_episode_steps=200, max_episode_steps=200,
kwargs={ kwargs={
"n_links": 5, "n_links": 5,
"random_start": True,
"allow_self_collision": False, "allow_self_collision": False,
"allow_wall_collision": False, "allow_wall_collision": False,
"hole_width": None, "hole_width": None,

View File

@ -14,7 +14,7 @@ class HoleReacherEnv(MPEnv):
def __init__(self, n_links: int, hole_x: Union[None, float] = None, hole_depth: Union[None, float] = None, def __init__(self, n_links: int, hole_x: Union[None, float] = None, hole_depth: Union[None, float] = None,
hole_width: float = 1., random_start: bool = False, allow_self_collision: bool = False, hole_width: float = 1., random_start: bool = False, allow_self_collision: bool = False,
allow_wall_collision: bool = False, collision_penalty: bool = 1000): allow_wall_collision: bool = False, collision_penalty: float = 1000):
self.n_links = n_links self.n_links = n_links
self.link_lengths = np.ones((n_links, 1)) self.link_lengths = np.ones((n_links, 1))
@ -52,7 +52,7 @@ class HoleReacherEnv(MPEnv):
[np.pi] * self.n_links, # sin [np.pi] * self.n_links, # sin
[np.inf] * self.n_links, # velocity [np.inf] * self.n_links, # velocity
[np.inf], # hole width [np.inf], # hole width
[np.inf], # hole depth # [np.inf], # hole depth
[np.inf] * 2, # x-y coordinates of target distance [np.inf] * 2, # x-y coordinates of target distance
[np.inf] # env steps, because reward start after n steps TODO: Maybe [np.inf] # env steps, because reward start after n steps TODO: Maybe
]) ])
@ -138,24 +138,20 @@ class HoleReacherEnv(MPEnv):
self._is_collided = self_collision or wall_collision self._is_collided = self_collision or wall_collision
def _get_reward(self, acc: np.ndarray): def _get_reward(self, acc: np.ndarray):
success = False reward = 0
reward = -np.inf # success = False
if not self._is_collided:
dist = 0 if self._steps == 199 or self._is_collided:
# return reward only in last time step # return reward only in last time step
if self._steps == 199: # Episode also terminates when colliding, hence return reward
dist = np.linalg.norm(self.end_effector - self._goal)
success = dist < 0.005
else:
# Episode terminates when colliding, hence return reward
dist = np.linalg.norm(self.end_effector - self._goal) dist = np.linalg.norm(self.end_effector - self._goal)
reward = -self.collision_penalty # success = dist < 0.005 and not self._is_collided
reward = - dist ** 2 - self.collision_penalty * self._is_collided
reward -= dist ** 2
reward -= 5e-8 * np.sum(acc ** 2) reward -= 5e-8 * np.sum(acc ** 2)
info = {"is_success": success} # info = {"is_success": success}
return reward, info return reward, {} # info
def _get_obs(self): def _get_obs(self):
theta = self._joint_angles theta = self._joint_angles
@ -164,7 +160,7 @@ class HoleReacherEnv(MPEnv):
np.sin(theta), np.sin(theta),
self._angle_velocity, self._angle_velocity,
self._tmp_hole_width, self._tmp_hole_width,
self._tmp_hole_depth, # self._tmp_hole_depth,
self.end_effector - self._goal, self.end_effector - self._goal,
self._steps self._steps
]) ])
@ -281,7 +277,7 @@ class HoleReacherEnv(MPEnv):
[self.random_start] * self.n_links, # sin [self.random_start] * self.n_links, # sin
[self.random_start] * self.n_links, # velocity [self.random_start] * self.n_links, # velocity
[self._hole_width is None], # hole width [self._hole_width is None], # hole width
[self._hole_depth is None], # hole width # [self._hole_depth is None], # hole depth
[True] * 2, # x-y coordinates of target distance [True] * 2, # x-y coordinates of target distance
[False] # env steps [False] # env steps
]) ])

View File

@ -2,16 +2,17 @@ def ccw(A, B, C):
return (C[1] - A[1]) * (B[0] - A[0]) - (B[1] - A[1]) * (C[0] - A[0]) > 1e-12 return (C[1] - A[1]) * (B[0] - A[0]) - (B[1] - A[1]) * (C[0] - A[0]) > 1e-12
# Return true if line segments AB and CD intersect
def intersect(A, B, C, D): def intersect(A, B, C, D):
"""
Checks whether line segments AB and CD intersect
"""
return ccw(A, C, D) != ccw(B, C, D) and ccw(A, B, C) != ccw(A, B, D) return ccw(A, C, D) != ccw(B, C, D) and ccw(A, B, C) != ccw(A, B, D)
def check_self_collision(line_points): def check_self_collision(line_points):
"Checks whether line segments and intersect"
for i, line1 in enumerate(line_points): for i, line1 in enumerate(line_points):
for line2 in line_points[i + 2:, :, :]: for line2 in line_points[i + 2:, :, :]:
# if line1 != line2:
if intersect(line1[0], line1[-1], line2[0], line2[-1]): if intersect(line1[0], line1[-1], line2[0], line2[-1]):
return True return True
return False return False