From c5109ec2e7c9ed98f807f07bfa6231d5b2f580c9 Mon Sep 17 00:00:00 2001 From: ottofabian Date: Thu, 24 Jun 2021 15:06:25 +0200 Subject: [PATCH] updated hole reacher example to new structure --- alr_envs/classic_control/hole_reacher.py | 30 +++++++++++++++--------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/alr_envs/classic_control/hole_reacher.py b/alr_envs/classic_control/hole_reacher.py index d5be111..d71c6d1 100644 --- a/alr_envs/classic_control/hole_reacher.py +++ b/alr_envs/classic_control/hole_reacher.py @@ -8,6 +8,7 @@ from matplotlib import patches from alr_envs.classic_control.utils import check_self_collision from mp_env_api.envs.mp_env import MpEnv +from mp_env_api.envs.mp_env_wrapper import MpEnvWrapper class HoleReacherEnv(MpEnv): @@ -92,7 +93,7 @@ class HoleReacherEnv(MpEnv): def reset(self): if self.random_start: - # Maybe change more than dirst seed + # Maybe change more than first seed first_joint = self.np_random.uniform(np.pi / 4, 3 * np.pi / 4) self._joint_angles = np.hstack([[first_joint], np.zeros(self.n_links - 1)]) self._start_pos = self._joint_angles.copy() @@ -276,6 +277,21 @@ class HoleReacherEnv(MpEnv): self.fig.gca().add_patch(right_block) self.fig.gca().add_patch(hole_floor) + def seed(self, seed=None): + self.np_random, seed = seeding.np_random(seed) + return [seed] + + @property + def end_effector(self): + return self._joints[self.n_links].T + + def close(self): + super().close() + if self.fig is not None: + plt.close(self.fig) + + +class HoleReacherMPWrapper(MpEnvWrapper): @property def active_obs(self): return np.hstack([ @@ -296,17 +312,9 @@ class HoleReacherEnv(MpEnv): def goal_pos(self) -> Union[float, int, np.ndarray]: raise ValueError("Goal position is not available and has to be learnt based on the environment.") - def seed(self, seed=None): - self.np_random, seed = seeding.np_random(seed) - return [seed] - @property - def end_effector(self): - return self._joints[self.n_links].T - - def close(self): - if self.fig is not None: - plt.close(self.fig) + def dt(self) -> Union[float, int]: + return self.env.dt if __name__ == '__main__':