fixed holereacher

This commit is contained in:
ottofabian 2021-07-19 14:04:14 +02:00
parent eae1013dac
commit c6b4cff3a3

View File

@ -123,26 +123,26 @@ class HoleReacherEnv(gym.Env):
def _generate_hole(self): def _generate_hole(self):
if self.initial_width is None: if self.initial_width is None:
width = self.np_random.uniform(0.15, 0.5, 1) width = self.np_random.uniform(0.15, 0.5)
else: else:
width = np.copy(self.initial_width) width = np.copy(self.initial_width)
if self.initial_x is None: if self.initial_x is None:
# sample whole on left or right side # sample whole on left or right side
direction = np.random.choice([-1, 1]) direction = np.random.choice([-1, 1])
# Hole center needs to be half the width away from the arm to give a valid setting. # Hole center needs to be half the width away from the arm to give a valid setting.
x = direction * self.np_random.uniform(width / 2, 3.5, 1) x = direction * self.np_random.uniform(width / 2, 3.5)
else: else:
x = np.copy(self.initial_x) x = np.copy(self.initial_x)
if self.initial_depth is None: if self.initial_depth is None:
# TODO we do not want this right now. # TODO we do not want this right now.
depth = self.np_random.uniform(1, 1, 1) depth = self.np_random.uniform(1, 1)
else: else:
depth = np.copy(self.initial_depth) depth = np.copy(self.initial_depth)
self._tmp_hole_width = width self._tmp_width = width
self._tmp_hole_x = x self._tmp_x = x
self._tmp_hole_depth = depth self._tmp_depth = depth
self._goal = np.hstack([self._tmp_hole_x, -self._tmp_hole_depth]) self._goal = np.hstack([self._tmp_x, -self._tmp_depth])
def _update_joints(self): def _update_joints(self):
""" """
@ -216,7 +216,6 @@ class HoleReacherEnv(gym.Env):
return np.squeeze(end_effector + self._joints[0, :]) return np.squeeze(end_effector + self._joints[0, :])
def _check_wall_collision(self, line_points): def _check_wall_collision(self, line_points):
# all points that are before the hole in x # all points that are before the hole in x
r, c = np.where(line_points[:, :, 0] < (self._tmp_x - self._tmp_width / 2)) r, c = np.where(line_points[:, :, 0] < (self._tmp_x - self._tmp_width / 2))