From 6e06e11cfa90487f7cb2de1cda20ed7c029d3210 Mon Sep 17 00:00:00 2001 From: Onur Date: Wed, 29 Jun 2022 10:39:28 +0200 Subject: [PATCH] added new mp wrappers to all environments --- alr_envs/alr/__init__.py | 2 +- .../hole_reacher/new_mp_wrapper.py | 31 ++++++++++++++++++ .../simple_reacher/new_mp_wrapper.py | 31 ++++++++++++++++++ .../viapoint_reacher/new_mp_wrapper.py | 32 +++++++++++++++++++ .../half_cheetah_jump/new_mp_wrapper.py | 24 ++++++++++++++ .../alr/mujoco/hopper_throw/mp_wrapper.py | 2 +- .../alr/mujoco/hopper_throw/new_mp_wrapper.py | 27 ++++++++++++++++ alr_envs/alr/mujoco/reacher/new_mp_wrapper.py | 8 +++-- .../mujoco/walker_2d_jump/new_mp_wrapper.py | 28 ++++++++++++++++ 9 files changed, 180 insertions(+), 5 deletions(-) create mode 100644 alr_envs/alr/classic_control/hole_reacher/new_mp_wrapper.py create mode 100644 alr_envs/alr/classic_control/simple_reacher/new_mp_wrapper.py create mode 100644 alr_envs/alr/classic_control/viapoint_reacher/new_mp_wrapper.py create mode 100644 alr_envs/alr/mujoco/half_cheetah_jump/new_mp_wrapper.py create mode 100644 alr_envs/alr/mujoco/hopper_throw/new_mp_wrapper.py create mode 100644 alr_envs/alr/mujoco/walker_2d_jump/new_mp_wrapper.py diff --git a/alr_envs/alr/__init__.py b/alr_envs/alr/__init__.py index d8169a8..4c90512 100644 --- a/alr_envs/alr/__init__.py +++ b/alr_envs/alr/__init__.py @@ -537,7 +537,7 @@ for _v in _versions: register( id=_env_id, entry_point='alr_envs.utils.make_env_helpers:make_mp_env_helper', - kwargs=kwargs_dict_bp_promp_fixed_release + kwargs=kwargs_dict_bp_promp ) ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) ######################################################################################################################## diff --git a/alr_envs/alr/classic_control/hole_reacher/new_mp_wrapper.py b/alr_envs/alr/classic_control/hole_reacher/new_mp_wrapper.py new file mode 100644 index 0000000..1f1d198 --- /dev/null +++ 
b/alr_envs/alr/classic_control/hole_reacher/new_mp_wrapper.py @@ -0,0 +1,31 @@ +from typing import Tuple, Union + +import numpy as np + +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper + + +class NewMPWrapper(RawInterfaceWrapper): + + def context_mask(self): + return np.hstack([ + [self.env.random_start] * self.env.n_links, # cos + [self.env.random_start] * self.env.n_links, # sin + [self.env.random_start] * self.env.n_links, # velocity + [self.env.initial_width is None], # hole width + # [self.env.hole_depth is None], # hole depth + [True] * 2, # x-y coordinates of target distance + [False] # env steps + ]) + + @property + def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: + return self.env.current_pos + + @property + def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: + return self.env.current_vel + + @property + def dt(self) -> Union[float, int]: + return self.env.dt diff --git a/alr_envs/alr/classic_control/simple_reacher/new_mp_wrapper.py b/alr_envs/alr/classic_control/simple_reacher/new_mp_wrapper.py new file mode 100644 index 0000000..c1497e6 --- /dev/null +++ b/alr_envs/alr/classic_control/simple_reacher/new_mp_wrapper.py @@ -0,0 +1,31 @@ +from typing import Tuple, Union + +import numpy as np + +from mp_env_api import MPEnvWrapper + +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper + + +class MPWrapper(RawInterfaceWrapper): + + def context_mask(self): + return np.hstack([ + [self.env.random_start] * self.env.n_links, # cos + [self.env.random_start] * self.env.n_links, # sin + [self.env.random_start] * self.env.n_links, # velocity + [True] * 2, # x-y coordinates of target distance + [False] # env steps + ]) + + @property + def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: + return self.env.current_pos + + @property + def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: + return self.env.current_vel + + @property + def dt(self) -> Union[float, int]: + return self.env.dt diff
--git a/alr_envs/alr/classic_control/viapoint_reacher/new_mp_wrapper.py b/alr_envs/alr/classic_control/viapoint_reacher/new_mp_wrapper.py new file mode 100644 index 0000000..f02dfe1 --- /dev/null +++ b/alr_envs/alr/classic_control/viapoint_reacher/new_mp_wrapper.py @@ -0,0 +1,32 @@ +from typing import Tuple, Union + +import numpy as np + +from mp_env_api import MPEnvWrapper + +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper + + +class MPWrapper(RawInterfaceWrapper): + + def context_mask(self): + return np.hstack([ + [self.env.random_start] * self.env.n_links, # cos + [self.env.random_start] * self.env.n_links, # sin + [self.env.random_start] * self.env.n_links, # velocity + [self.env.initial_via_target is None] * 2, # x-y coordinates of via point distance + [True] * 2, # x-y coordinates of target distance + [False] # env steps + ]) + + @property + def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: + return self.env.current_pos + + @property + def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: + return self.env.current_vel + + @property + def dt(self) -> Union[float, int]: + return self.env.dt diff --git a/alr_envs/alr/mujoco/half_cheetah_jump/new_mp_wrapper.py b/alr_envs/alr/mujoco/half_cheetah_jump/new_mp_wrapper.py new file mode 100644 index 0000000..f098c2d --- /dev/null +++ b/alr_envs/alr/mujoco/half_cheetah_jump/new_mp_wrapper.py @@ -0,0 +1,24 @@ +from typing import Tuple, Union + +import numpy as np + +from mp_env_api import MPEnvWrapper + +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper + + +class MPWrapper(RawInterfaceWrapper): + def context_mask(self): + return np.hstack([ + [False] * 17, + [True] # goal height + ]) + + @property + def current_pos(self) -> Union[float, int, np.ndarray]: + return self.env.sim.data.qpos[3:9].copy() + + @property + def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: + return self.env.sim.data.qvel[3:9].copy() + diff --git 
a/alr_envs/alr/mujoco/hopper_throw/mp_wrapper.py b/alr_envs/alr/mujoco/hopper_throw/mp_wrapper.py index e5e9486..909e00a 100644 --- a/alr_envs/alr/mujoco/hopper_throw/mp_wrapper.py +++ b/alr_envs/alr/mujoco/hopper_throw/mp_wrapper.py @@ -10,7 +10,7 @@ class MPWrapper(MPEnvWrapper): def active_obs(self): return np.hstack([ [False] * 17, - [True] # goal pos + [True] # goal pos ]) @property diff --git a/alr_envs/alr/mujoco/hopper_throw/new_mp_wrapper.py b/alr_envs/alr/mujoco/hopper_throw/new_mp_wrapper.py new file mode 100644 index 0000000..049c2f0 --- /dev/null +++ b/alr_envs/alr/mujoco/hopper_throw/new_mp_wrapper.py @@ -0,0 +1,27 @@ +from typing import Tuple, Union + +import numpy as np + +from mp_env_api import MPEnvWrapper + +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper + + +class MPWrapper(RawInterfaceWrapper): + def context_mask(self): + return np.hstack([ + [False] * 17, + [True] # goal pos + ]) + + @property + def current_pos(self) -> Union[float, int, np.ndarray]: + return self.env.sim.data.qpos[3:6].copy() + + @property + def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: + return self.env.sim.data.qvel[3:6].copy() + + @property + def dt(self) -> Union[float, int]: + return self.env.dt diff --git a/alr_envs/alr/mujoco/reacher/new_mp_wrapper.py b/alr_envs/alr/mujoco/reacher/new_mp_wrapper.py index 8df365a..54910e5 100644 --- a/alr_envs/alr/mujoco/reacher/new_mp_wrapper.py +++ b/alr_envs/alr/mujoco/reacher/new_mp_wrapper.py @@ -2,8 +2,10 @@ from alr_envs.mp.black_box_wrapper import BlackBoxWrapper from typing import Union, Tuple import numpy as np +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper -class MPWrapper(BlackBoxWrapper): + +class MPWrapper(RawInterfaceWrapper): @property def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: @@ -12,7 +14,7 @@ class MPWrapper(BlackBoxWrapper): def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: return self.env.sim.data.qvel.flat[:self.env.n_links] - 
def get_context_mask(self): + def context_mask(self): return np.concatenate([ [False] * self.env.n_links, # cos [False] * self.env.n_links, # sin @@ -21,4 +23,4 @@ class MPWrapper(BlackBoxWrapper): [False] * 3, # goal distance # self.get_body_com("target"), # only return target to make problem harder [False], # step - ]) \ No newline at end of file + ]) diff --git a/alr_envs/alr/mujoco/walker_2d_jump/new_mp_wrapper.py b/alr_envs/alr/mujoco/walker_2d_jump/new_mp_wrapper.py new file mode 100644 index 0000000..dde928f --- /dev/null +++ b/alr_envs/alr/mujoco/walker_2d_jump/new_mp_wrapper.py @@ -0,0 +1,28 @@ +from typing import Tuple, Union + +import numpy as np + +from mp_env_api import MPEnvWrapper + +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper + + +class MPWrapper(RawInterfaceWrapper): + def context_mask(self): + return np.hstack([ + [False] * 17, + [True] # goal pos + ]) + + @property + def current_pos(self) -> Union[float, int, np.ndarray]: + return self.env.sim.data.qpos[3:9].copy() + + @property + def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: + return self.env.sim.data.qvel[3:9].copy() + + @property + def goal_pos(self) -> Union[float, int, np.ndarray, Tuple]: + raise ValueError("Goal position is not available and has to be learnt based on the environment.") +