From fa7dfdc0813ec12e31883fd7d18270ac9469af9e Mon Sep 17 00:00:00 2001 From: ottofabian Date: Thu, 24 Jun 2021 15:19:05 +0200 Subject: [PATCH] updated hole reacher example to new structure --- alr_envs/classic_control/hole_reacher.py | 55 +++++++----------------- 1 file changed, 16 insertions(+), 39 deletions(-) diff --git a/alr_envs/classic_control/hole_reacher.py b/alr_envs/classic_control/hole_reacher.py index d71c6d1..9100686 100644 --- a/alr_envs/classic_control/hole_reacher.py +++ b/alr_envs/classic_control/hole_reacher.py @@ -8,10 +8,10 @@ from matplotlib import patches from alr_envs.classic_control.utils import check_self_collision from mp_env_api.envs.mp_env import MpEnv -from mp_env_api.envs.mp_env_wrapper import MpEnvWrapper +from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper -class HoleReacherEnv(MpEnv): +class HoleReacherEnv(gym.Env): def __init__(self, n_links: int, hole_x: Union[None, float] = None, hole_depth: Union[None, float] = None, hole_width: float = 1., random_start: bool = False, allow_self_collision: bool = False, @@ -23,9 +23,9 @@ class HoleReacherEnv(MpEnv): self.random_start = random_start # provided initial parameters - self._hole_x = hole_x # x-position of center of hole - self._hole_width = hole_width # width of hole - self._hole_depth = hole_depth # depth of hole + self.hole_x = hole_x # x-position of center of hole + self.hole_width = hole_width # width of hole + self.hole_depth = hole_depth # depth of hole # temp container for current env state self._tmp_hole_x = None @@ -112,12 +112,12 @@ class HoleReacherEnv(MpEnv): return self._get_obs().copy() def _generate_hole(self): - self._tmp_hole_x = self.np_random.uniform(1, 3.5, 1) if self._hole_x is None else np.copy(self._hole_x) - self._tmp_hole_width = self.np_random.uniform(0.15, 0.5, 1) if self._hole_width is None else np.copy( - self._hole_width) + self._tmp_hole_x = self.np_random.uniform(1, 3.5, 1) if self.hole_x is None else np.copy(self.hole_x) + self._tmp_hole_width = self.np_random.uniform(0.15, 0.5, 1) if self.hole_width is None else np.copy( + self.hole_width) # TODO we do not want this right now. - self._tmp_hole_depth = self.np_random.uniform(1, 1, 1) if self._hole_depth is None else np.copy( - self._hole_depth) + self._tmp_hole_depth = self.np_random.uniform(1, 1, 1) if self.hole_depth is None else np.copy( + self.hole_depth) self._goal = np.hstack([self._tmp_hole_x, -self._tmp_hole_depth]) def _update_joints(self): @@ -291,15 +291,15 @@ class HoleReacherEnv(MpEnv): plt.close(self.fig) -class HoleReacherMPWrapper(MpEnvWrapper): +class HoleReacherMPWrapper(MPEnvWrapper): @property def active_obs(self): return np.hstack([ - [self.random_start] * self.n_links, # cos - [self.random_start] * self.n_links, # sin - [self.random_start] * self.n_links, # velocity - [self._hole_width is None], # hole width - # [self._hole_depth is None], # hole depth + [self.env.random_start] * self.env.n_links, # cos + [self.env.random_start] * self.env.n_links, # sin + [self.env.random_start] * self.env.n_links, # velocity + [self.env.hole_width is None], # hole width + # [self.env.hole_depth is None], # hole depth [True] * 2, # x-y coordinates of target distance [False] # env steps ]) @@ -315,26 +315,3 @@ class HoleReacherMPWrapper(MpEnvWrapper): @property def dt(self) -> Union[float, int]: return self.env.dt - - -if __name__ == '__main__': - nl = 5 - render_mode = "human" # "human" or "partial" or "final" - env = HoleReacherEnv(n_links=nl, allow_self_collision=False, allow_wall_collision=False, hole_width=None, - hole_depth=1, hole_x=None) - obs = env.reset() - - for i in range(2000): - # objective.load_result("/tmp/cma") - # test with random actions - ac = 2 * env.action_space.sample() - obs, rew, d, info = env.step(ac) - if i % 10 == 0: - env.render(mode=render_mode) - - print(rew) - - if d: - env.reset() - - env.close()