From 4c334b1129a4d735cdedd16bd57765200fd264d0 Mon Sep 17 00:00:00 2001 From: ottofabian Date: Thu, 24 Jun 2021 11:39:26 +0200 Subject: [PATCH] Hole Reacher extended to have holes in both directions --- alr_envs/classic_control/hole_reacher.py | 34 +++++++++++++++++------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/alr_envs/classic_control/hole_reacher.py b/alr_envs/classic_control/hole_reacher.py index 730e7bf..ce6d9c3 100644 --- a/alr_envs/classic_control/hole_reacher.py +++ b/alr_envs/classic_control/hole_reacher.py @@ -107,12 +107,29 @@ class HoleReacherEnv(AlrEnv): return self._get_obs().copy() def _generate_hole(self): - self._tmp_hole_x = self.np_random.uniform(1, 3.5, 1) if self._hole_x is None else np.copy(self._hole_x) - self._tmp_hole_width = self.np_random.uniform(0.15, 0.5, 1) if self._hole_width is None else np.copy( - self._hole_width) - # TODO we do not want this right now. - self._tmp_hole_depth = self.np_random.uniform(1, 1, 1) if self._hole_depth is None else np.copy( - self._hole_depth) + if self._hole_width is None: + width = self.np_random.uniform(0.15, 0.5, 1) + else: + width = np.copy(self._hole_width) + + if self._hole_x is None: + # sample whole on left or right side + direction = np.random.choice([-1, 1]) + # Hole center needs to be half the width away from the arm to give a valid setting. + x = direction * self.np_random.uniform(width / 2, 3.5, 1) + else: + x = np.copy(self._hole_x) + + if self._hole_depth is None: + # TODO we do not want this right now. + depth = self.np_random.uniform(1, 1, 1) + else: + depth = np.copy(self._hole_depth) + + self._tmp_hole_width = width + self._tmp_hole_x = x + self._tmp_hole_depth = depth + self._goal = np.hstack([self._tmp_hole_x, -self._tmp_hole_depth]) def _update_joints(self): @@ -315,13 +332,10 @@ if __name__ == '__main__': for i in range(2000): # objective.load_result("/tmp/cma") # test with random actions - ac = 2 * env.action_space.sample() + ac = env.action_space.sample() obs, rew, d, info = env.step(ac) if i % 10 == 0: env.render(mode=render_mode) - - print(rew) - if d: env.reset()