fixed holereacher

This commit is contained in:
ottofabian 2021-07-19 11:57:06 +02:00
parent 618d333116
commit eae1013dac

View File

@ -122,13 +122,27 @@ class HoleReacherEnv(gym.Env):
return self._get_obs().copy() return self._get_obs().copy()
def _generate_hole(self): def _generate_hole(self):
self._tmp_x = self.np_random.uniform(1, 3.5, 1) if self.initial_x is None else np.copy(self.initial_x) if self.initial_width is None:
self._tmp_width = self.np_random.uniform(0.15, 0.5, 1) if self.initial_width is None else np.copy( width = self.np_random.uniform(0.15, 0.5, 1)
self.initial_width) else:
width = np.copy(self.initial_width)
if self.initial_x is None:
# sample whole on left or right side
direction = np.random.choice([-1, 1])
# Hole center needs to be half the width away from the arm to give a valid setting.
x = direction * self.np_random.uniform(width / 2, 3.5, 1)
else:
x = np.copy(self.initial_x)
if self.initial_depth is None:
# TODO we do not want this right now. # TODO we do not want this right now.
self._tmp_depth = self.np_random.uniform(1, 1, 1) if self.initial_depth is None else np.copy( depth = self.np_random.uniform(1, 1, 1)
self.initial_depth) else:
self._goal = np.hstack([self._tmp_x, -self._tmp_depth]) depth = np.copy(self.initial_depth)
self._tmp_hole_width = width
self._tmp_hole_x = x
self._tmp_hole_depth = depth
self._goal = np.hstack([self._tmp_hole_x, -self._tmp_hole_depth])
def _update_joints(self): def _update_joints(self):
""" """