diff --git a/alr_envs/classic_control/simple_reacher.py b/alr_envs/classic_control/simple_reacher.py
index f564f89..04a1110 100644
--- a/alr_envs/classic_control/simple_reacher.py
+++ b/alr_envs/classic_control/simple_reacher.py
@@ -6,6 +6,7 @@ from gym import spaces
 from gym.utils import seeding
 
 from mp_env_api.envs.mp_env import MpEnv
+from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper
 
 
 class SimpleReacherEnv(MpEnv):
@@ -19,7 +20,7 @@ class SimpleReacherEnv(MpEnv):
         super().__init__()
         self.link_lengths = np.ones(n_links)
         self.n_links = n_links
-        self.dt = 0.1
+        self._dt = 0.1
 
         self.random_start = random_start
 
@@ -53,6 +54,10 @@ class SimpleReacherEnv(MpEnv):
         self._steps = 0
         self.seed()
 
+    @property
+    def dt(self) -> Union[float, int]:
+        return self._dt
+
     def step(self, action: np.ndarray):
         """
         A single step with action in torque space
@@ -172,24 +177,6 @@ class SimpleReacherEnv(MpEnv):
         self.fig.canvas.draw()
         self.fig.canvas.flush_events()
 
-    @property
-    def active_obs(self):
-        return np.hstack([
-            [self.random_start] * self.n_links,  # cos
-            [self.random_start] * self.n_links,  # sin
-            [self.random_start] * self.n_links,  # velocity
-            [True] * 2,  # x-y coordinates of target distance
-            [False]  # env steps
-        ])
-
-    @property
-    def start_pos(self):
-        return self._start_pos
-
-    @property
-    def goal_pos(self):
-        raise ValueError("Goal position is not available and has to be learnt based on the environment.")
-
     def seed(self, seed=None):
         self.np_random, seed = seeding.np_random(seed)
         return [seed]
@@ -202,24 +189,25 @@ class SimpleReacherEnv(MpEnv):
         return self._joints[self.n_links].T
 
 
-if __name__ == '__main__':
-    nl = 5
-    render_mode = "human"  # "human" or "partial" or "final"
-    env = SimpleReacherEnv(n_links=nl)
-    obs = env.reset()
-    print("First", obs)
+class SimpleReacherMPWrapper(MPEnvWrapper):
+    @property
+    def active_obs(self):
+        return np.hstack([
+            [self.env.random_start] * self.env.n_links,  # cos
+            [self.env.random_start] * self.env.n_links,  # sin
+            [self.env.random_start] * self.env.n_links,  # velocity
+            [True] * 2,  # x-y coordinates of target distance
+            [False]  # env steps
+        ])
 
-    for i in range(2000):
-        # objective.load_result("/tmp/cma")
-        # test with random actions
-        ac = 2 * env.action_space.sample()
-        # ac = np.ones(env.action_space.shape)
-        obs, rew, d, info = env.step(ac)
-        env.render(mode=render_mode)
+    @property
+    def start_pos(self):
+        return self._start_pos
 
-    print(obs[env.active_obs].shape)
+    @property
+    def goal_pos(self):
+        raise ValueError("Goal position is not available and has to be learnt based on the environment.")
 
-    if d or i % 200 == 0:
-        env.reset()
-
-    env.close()
+    @property
+    def dt(self) -> Union[float, int]:
+        return self.env.dt
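
For reference, a minimal usage sketch of the new wrapper, standing in for the `__main__` demo the diff removes. This is illustrative only and not part of the patch: it assumes both classes are importable from `alr_envs.classic_control.simple_reacher` and that `MPEnvWrapper` behaves like a standard `gym.Wrapper`, forwarding `reset`/`step`/`close` and `action_space` to the wrapped environment.

    from alr_envs.classic_control.simple_reacher import (
        SimpleReacherEnv,
        SimpleReacherMPWrapper,
    )

    # Wrap the plain torque-space env so the MP-specific properties
    # (active_obs, start_pos, dt) become available to a controller.
    env = SimpleReacherMPWrapper(SimpleReacherEnv(n_links=5))
    obs = env.reset()

    for i in range(200):
        # Random torques, just to exercise the interface.
        obs, reward, done, info = env.step(env.action_space.sample())
        # active_obs is a boolean mask over the observation vector; it
        # drops entries a movement-primitive controller should not see,
        # e.g. the env-step counter appended to the state.
        print(obs[env.active_obs].shape, env.dt)
        if done:
            obs = env.reset()

    env.close()

Note that `env.dt` now goes through the wrapper's `dt` property, which delegates to the environment's new read-only `dt` (backed by `_dt`).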