Hole Reacher extended to have holes in both directions

This commit is contained in:
ottofabian 2021-06-24 11:39:26 +02:00
parent 6b0dfd7c24
commit 4c334b1129

View File

@ -107,12 +107,29 @@ class HoleReacherEnv(AlrEnv):
return self._get_obs().copy()
def _generate_hole(self):
    """Set the temporary hole parameters and the goal derived from them.

    Each of ``self._hole_width``, ``self._hole_x`` and ``self._hole_depth``
    is copied when it is set; when it is ``None`` the value is drawn from
    ``self.np_random`` instead:

    * width: uniform in ``[0.15, 0.5)``
    * x: uniform in ``[width / 2, 3.5)``, mirrored to the negative side
      with probability 0.5
    * depth: currently always ``1`` (see the TODO below)

    The results are stored in ``self._tmp_hole_width``,
    ``self._tmp_hole_x`` and ``self._tmp_hole_depth``, and ``self._goal``
    is set to ``[x, -depth]``.
    """
    if self._hole_width is None:
        width = self.np_random.uniform(0.15, 0.5, 1)
    else:
        width = np.copy(self._hole_width)

    if self._hole_x is None:
        # Sample the hole on the left or right side. The draw comes from the
        # environment's own generator (not the global ``np.random``) so that
        # the result is determined by the state of ``self.np_random`` alone.
        direction = self.np_random.choice([-1, 1])
        # Hole center needs to be half the width away from the arm to give a
        # valid setting.
        x = direction * self.np_random.uniform(width / 2, 3.5, 1)
    else:
        x = np.copy(self._hole_x)

    if self._hole_depth is None:
        # TODO we do not want this right now.
        depth = self.np_random.uniform(1, 1, 1)
    else:
        depth = np.copy(self._hole_depth)

    self._tmp_hole_width = width
    self._tmp_hole_x = x
    self._tmp_hole_depth = depth
    self._goal = np.hstack([self._tmp_hole_x, -self._tmp_hole_depth])
def _update_joints(self):
@ -315,13 +332,10 @@ if __name__ == '__main__':
for i in range(2000):
# objective.load_result("/tmp/cma")
# test with random actions
ac = 2 * env.action_space.sample()
ac = env.action_space.sample()
obs, rew, d, info = env.step(ac)
if i % 10 == 0:
env.render(mode=render_mode)
print(rew)
if d:
env.reset()