fixed holereacher
This commit is contained in:
parent
eae1013dac
commit
c6b4cff3a3
@ -123,26 +123,26 @@ class HoleReacherEnv(gym.Env):
|
|||||||
|
|
||||||
def _generate_hole(self):
|
def _generate_hole(self):
|
||||||
if self.initial_width is None:
|
if self.initial_width is None:
|
||||||
width = self.np_random.uniform(0.15, 0.5, 1)
|
width = self.np_random.uniform(0.15, 0.5)
|
||||||
else:
|
else:
|
||||||
width = np.copy(self.initial_width)
|
width = np.copy(self.initial_width)
|
||||||
if self.initial_x is None:
|
if self.initial_x is None:
|
||||||
# sample whole on left or right side
|
# sample whole on left or right side
|
||||||
direction = np.random.choice([-1, 1])
|
direction = np.random.choice([-1, 1])
|
||||||
# Hole center needs to be half the width away from the arm to give a valid setting.
|
# Hole center needs to be half the width away from the arm to give a valid setting.
|
||||||
x = direction * self.np_random.uniform(width / 2, 3.5, 1)
|
x = direction * self.np_random.uniform(width / 2, 3.5)
|
||||||
else:
|
else:
|
||||||
x = np.copy(self.initial_x)
|
x = np.copy(self.initial_x)
|
||||||
if self.initial_depth is None:
|
if self.initial_depth is None:
|
||||||
# TODO we do not want this right now.
|
# TODO we do not want this right now.
|
||||||
depth = self.np_random.uniform(1, 1, 1)
|
depth = self.np_random.uniform(1, 1)
|
||||||
else:
|
else:
|
||||||
depth = np.copy(self.initial_depth)
|
depth = np.copy(self.initial_depth)
|
||||||
|
|
||||||
self._tmp_hole_width = width
|
self._tmp_width = width
|
||||||
self._tmp_hole_x = x
|
self._tmp_x = x
|
||||||
self._tmp_hole_depth = depth
|
self._tmp_depth = depth
|
||||||
self._goal = np.hstack([self._tmp_hole_x, -self._tmp_hole_depth])
|
self._goal = np.hstack([self._tmp_x, -self._tmp_depth])
|
||||||
|
|
||||||
def _update_joints(self):
|
def _update_joints(self):
|
||||||
"""
|
"""
|
||||||
@ -216,7 +216,6 @@ class HoleReacherEnv(gym.Env):
|
|||||||
return np.squeeze(end_effector + self._joints[0, :])
|
return np.squeeze(end_effector + self._joints[0, :])
|
||||||
|
|
||||||
def _check_wall_collision(self, line_points):
|
def _check_wall_collision(self, line_points):
|
||||||
|
|
||||||
# all points that are before the hole in x
|
# all points that are before the hole in x
|
||||||
r, c = np.where(line_points[:, :, 0] < (self._tmp_x - self._tmp_width / 2))
|
r, c = np.where(line_points[:, :, 0] < (self._tmp_x - self._tmp_width / 2))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user