fixed holereacher

This commit is contained in:
ottofabian 2021-07-19 14:04:14 +02:00
parent eae1013dac
commit c6b4cff3a3

View File

@ -123,26 +123,26 @@ class HoleReacherEnv(gym.Env):
def _generate_hole(self):
if self.initial_width is None:
width = self.np_random.uniform(0.15, 0.5, 1)
width = self.np_random.uniform(0.15, 0.5)
else:
width = np.copy(self.initial_width)
if self.initial_x is None:
# sample whole on left or right side
direction = np.random.choice([-1, 1])
# Hole center needs to be half the width away from the arm to give a valid setting.
x = direction * self.np_random.uniform(width / 2, 3.5, 1)
x = direction * self.np_random.uniform(width / 2, 3.5)
else:
x = np.copy(self.initial_x)
if self.initial_depth is None:
# TODO we do not want this right now.
depth = self.np_random.uniform(1, 1, 1)
depth = self.np_random.uniform(1, 1)
else:
depth = np.copy(self.initial_depth)
self._tmp_hole_width = width
self._tmp_hole_x = x
self._tmp_hole_depth = depth
self._goal = np.hstack([self._tmp_hole_x, -self._tmp_hole_depth])
self._tmp_width = width
self._tmp_x = x
self._tmp_depth = depth
self._goal = np.hstack([self._tmp_x, -self._tmp_depth])
def _update_joints(self):
"""
@ -216,7 +216,6 @@ class HoleReacherEnv(gym.Env):
return np.squeeze(end_effector + self._joints[0, :])
def _check_wall_collision(self, line_points):
# all points that are before the hole in x
r, c = np.where(line_points[:, :, 0] < (self._tmp_x - self._tmp_width / 2))