diff --git a/alr_envs/alr/mujoco/beerpong/mp_wrapper.py b/alr_envs/alr/mujoco/beerpong/mp_wrapper.py new file mode 100644 index 0000000..e69d4f9 --- /dev/null +++ b/alr_envs/alr/mujoco/beerpong/mp_wrapper.py @@ -0,0 +1,42 @@ +from typing import Union, Tuple + +import numpy as np + +from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper + + +class MPWrapper(RawInterfaceWrapper): + + def get_context_mask(self): + return np.hstack([ + [False] * 7, # cos + [False] * 7, # sin + [False] * 7, # joint velocities + [False] * 3, # cup_goal_diff_final + [False] * 3, # cup_goal_diff_top + [True] * 2, # xy position of cup + [False] # env steps + ]) + + @property + def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: + return self.env.sim.data.qpos[0:7].copy() + + @property + def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: + return self.env.sim.data.qvel[0:7].copy() + + # TODO: Fix this + def _episode_callback(self, action: np.ndarray, mp) -> Tuple[np.ndarray, Union[np.ndarray, None]]: + if mp.learn_tau: + self.env.env.release_step = action[0] / self.env.dt # Tau value + return action, None + else: + return action, None + + def set_context(self, context): + xyz = np.zeros(3) + xyz[:2] = context + xyz[-1] = 0.840 + self.env.env.model.body_pos[self.env.env.cup_table_id] = xyz + return self.get_observation_from_step(self.env.env._get_obs()) diff --git a/setup.py b/setup.py index 055ac81..e99b393 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ from setuptools import setup # Environment-specific dependencies for dmc and metaworld extras = { "dmc": ["dm_control"], - "meta": ["mujoco_py<2.2,>=2.1, git+https://github.com/rlworkgroup/metaworld.git@master#egg=metaworld"], + "meta": ["metaworld @ git+https://github.com/rlworkgroup/metaworld.git@master#egg=metaworld"], "mujoco": ["mujoco==2.2.0", "imageio>=2.14.1"], } @@ -20,13 +20,7 @@ setup( packages=['alr_envs', 'alr_envs.alr', 'alr_envs.open_ai', 'alr_envs.dmc', 'alr_envs.meta', 'alr_envs.utils'], install_requires=[ 'gym', - 'PyQt5', - # 'matplotlib', - # 'mp_env_api @ git+https://github.com/ALRhub/motion_primitive_env_api.git', - # 'mp_env_api @ git+ssh://git@github.com/ALRhub/motion_primitive_env_api.git', - 'mujoco-py<2.1,>=2.0', - 'dm_control', - 'metaworld @ git+https://github.com/rlworkgroup/metaworld.git@master#egg=metaworld', + "mujoco_py<2.2,>=2.1", ], url='https://github.com/ALRhub/alr_envs/', # license='AGPL-3.0 license',