diff --git a/alr_envs/classic_control/viapoint_reacher.py b/alr_envs/classic_control/viapoint_reacher.py index 15934a8..2965df4 100644 --- a/alr_envs/classic_control/viapoint_reacher.py +++ b/alr_envs/classic_control/viapoint_reacher.py @@ -7,9 +7,10 @@ from gym.utils import seeding from alr_envs.classic_control.utils import check_self_collision from mp_env_api.envs.mp_env import MpEnv +from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper -class ViaPointReacher(MpEnv): +class ViaPointReacher(gym.Env): def __init__(self, n_links, random_start: bool = True, via_target: Union[None, Iterable] = None, target: Union[None, Iterable] = None, allow_self_collision=False, collision_penalty=1000): @@ -20,8 +21,8 @@ class ViaPointReacher(MpEnv): self.random_start = random_start # provided initial parameters - self._target = target # provided target value - self._via_target = via_target # provided via point target value + self.target = target # provided target value + self.via_target = via_target # provided via point target value # temp container for current env state self._via_point = np.ones(2) @@ -39,7 +40,7 @@ class ViaPointReacher(MpEnv): self._start_vel = np.zeros(self.n_links) self.weight_matrix_scale = 1 - self.dt = 0.01 + self._dt = 0.01 action_bound = np.pi * np.ones((self.n_links,)) state_bound = np.hstack([ @@ -60,6 +61,10 @@ class ViaPointReacher(MpEnv): self._steps = 0 self.seed() + @property + def dt(self): + return self._dt + def step(self, action: np.ndarray): """ a single step with an action in joint velocity space @@ -104,22 +109,22 @@ class ViaPointReacher(MpEnv): total_length = np.sum(self.link_lengths) # rejection sampled point in inner circle with 0.5*Radius - if self._via_target is None: + if self.via_target is None: via_target = np.array([total_length, total_length]) while np.linalg.norm(via_target) >= 0.5 * total_length: via_target = self.np_random.uniform(low=-0.5 * total_length, high=0.5 * total_length, size=2) else: - via_target = np.copy(self._via_target) + via_target = np.copy(self.via_target) # rejection sampled point in outer circle - if self._target is None: + if self.target is None: goal = np.array([total_length, total_length]) while np.linalg.norm(goal) >= total_length or np.linalg.norm(goal) <= 0.5 * total_length: goal = self.np_random.uniform(low=-total_length, high=total_length, size=2) else: - goal = np.copy(self._target) + goal = np.copy(self.target) - self._via_target = via_target + self.via_target = via_target self._goal = goal def _update_joints(self): @@ -266,25 +271,6 @@ class ViaPointReacher(MpEnv): plt.pause(0.01) - @property - def active_obs(self): - return np.hstack([ - [self.random_start] * self.n_links, # cos - [self.random_start] * self.n_links, # sin - [self.random_start] * self.n_links, # velocity - [self._via_target is None] * 2, # x-y coordinates of via point distance - [True] * 2, # x-y coordinates of target distance - [False] # env steps - ]) - - @property - def start_pos(self) -> Union[float, int, np.ndarray]: - return self._start_pos - - @property - def goal_pos(self) -> Union[float, int, np.ndarray]: - raise ValueError("Goal position is not available and has to be learnt based on the environment.") - def seed(self, seed=None): self.np_random, seed = seeding.np_random(seed) return [seed] @@ -298,24 +284,25 @@ class ViaPointReacher(MpEnv): plt.close(self.fig) -if __name__ == '__main__': - nl = 5 - render_mode = "human" # "human" or "partial" or "final" - env = ViaPointReacher(n_links=nl, allow_self_collision=False) - env.reset() - env.render(mode=render_mode) +class ViaPointReacherMPWrapper(MPEnvWrapper): + @property + def active_obs(self): + return np.hstack([ + [self.env.random_start] * self.env.n_links, # cos + [self.env.random_start] * self.env.n_links, # sin + [self.env.random_start] * self.env.n_links, # velocity + [self.env.via_target is None] * 2, # x-y coordinates of via point distance + [True] * 2, # x-y coordinates of target distance + [False] # env steps + ]) - for i in range(300): - # objective.load_result("/tmp/cma") - # test with random actions - ac = env.action_space.sample() - # ac[0] += np.pi/2 - obs, rew, d, info = env.step(ac) - env.render(mode=render_mode) + @property + def start_pos(self) -> Union[float, int, np.ndarray]: + return self._start_pos - print(rew) + @property + def goal_pos(self) -> Union[float, int, np.ndarray]: + raise ValueError("Goal position is not available and has to be learnt based on the environment.") - if d: - break - - env.close() + def dt(self) -> Union[float, int]: + return self.env.dt