diff --git a/alr_envs/alr/__init__.py b/alr_envs/alr/__init__.py index 607ef18..435cfdb 100644 --- a/alr_envs/alr/__init__.py +++ b/alr_envs/alr/__init__.py @@ -25,21 +25,13 @@ ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []} DEFAULT_MP_ENV_DICT = { "name": 'EnvName', "wrappers": [], - # TODO move scale to traj_gen - "ep_wrapper_kwargs": { - "weight_scale": 1 + "traj_gen_kwargs": { + "weight_scale": 1, + 'movement_primitives_type': 'promp' }, - # TODO traj_gen_kwargs - # TODO remove action_dim - "movement_primitives_kwargs": { - 'movement_primitives_type': 'promp', - 'action_dim': 7 - }, - # TODO remove tau "phase_generator_kwargs": { 'phase_generator_type': 'linear', 'delay': 0, - 'tau': 1.5, # initial value 'learn_tau': False, 'learn_delay': False }, @@ -51,7 +43,7 @@ DEFAULT_MP_ENV_DICT = { "basis_generator_kwargs": { 'basis_generator_type': 'zero_rbf', 'num_basis': 5, - 'num_basis_zero_start': 2 # TODO: Change to 1 + 'num_basis_zero_start': 1 } } @@ -370,9 +362,7 @@ for _v in _versions: _env_id = f'{_name[0]}ProMP-{_name[1]}' kwargs_dict_simple_reacher_promp = deepcopy(DEFAULT_MP_ENV_DICT) - kwargs_dict_simple_reacher_promp['wrappers'].append('TODO') # TODO - kwargs_dict_simple_reacher_promp['movement_primitives_kwargs']['action_dim'] = 2 if "long" not in _v.lower() else 5 - kwargs_dict_simple_reacher_promp['phase_generator_kwargs']['tau'] = 2 + kwargs_dict_simple_reacher_promp['wrappers'].append(classic_control.simple_reacher.MPWrapper) kwargs_dict_simple_reacher_promp['controller_kwargs']['p_gains'] = 0.6 kwargs_dict_simple_reacher_promp['controller_kwargs']['d_gains'] = 0.075 kwargs_dict_simple_reacher_promp['name'] = _env_id @@ -405,7 +395,7 @@ register( ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append("ViaPointReacherDMP-v0") kwargs_dict_via_point_reacher_promp = deepcopy(DEFAULT_MP_ENV_DICT) -kwargs_dict_via_point_reacher_promp['wrappers'].append('TODO') # TODO +kwargs_dict_via_point_reacher_promp['wrappers'].append(classic_control.viapoint_reacher.MPWrapper) kwargs_dict_via_point_reacher_promp['controller_kwargs']['controller_type'] = 'velocity' kwargs_dict_via_point_reacher_promp['name'] = "ViaPointReacherProMP-v0" register( @@ -444,12 +434,9 @@ for _v in _versions: _env_id = f'{_name[0]}ProMP-{_name[1]}' kwargs_dict_hole_reacher_promp = deepcopy(DEFAULT_MP_ENV_DICT) - kwargs_dict_hole_reacher_promp['wrappers'].append('TODO') # TODO - kwargs_dict_hole_reacher_promp['ep_wrapper_kwargs']['weight_scale'] = 2 - # kwargs_dict_hole_reacher_promp['movement_primitives_kwargs']['action_dim'] = 5 - # kwargs_dict_hole_reacher_promp['phase_generator_kwargs']['tau'] = 2 + kwargs_dict_hole_reacher_promp['wrappers'].append(classic_control.hole_reacher.MPWrapper) + kwargs_dict_hole_reacher_promp['traj_gen_kwargs']['weight_scale'] = 2 kwargs_dict_hole_reacher_promp['controller_kwargs']['controller_type'] = 'velocity' - # kwargs_dict_hole_reacher_promp['basis_generator_kwargs']['num_basis'] = 5 kwargs_dict_hole_reacher_promp['name'] = f"alr_envs:{_v}" register( id=_env_id, @@ -489,9 +476,7 @@ for _v in _versions: _env_id = f'{_name[0]}ProMP-{_name[1]}' kwargs_dict_alr_reacher_promp = deepcopy(DEFAULT_MP_ENV_DICT) - kwargs_dict_alr_reacher_promp['wrappers'].append('TODO') # TODO - kwargs_dict_alr_reacher_promp['movement_primitives_kwargs']['action_dim'] = 5 if "long" not in _v.lower() else 7 - kwargs_dict_alr_reacher_promp['phase_generator_kwargs']['tau'] = 4 + kwargs_dict_alr_reacher_promp['wrappers'].append(mujoco.reacher.MPWrapper) kwargs_dict_alr_reacher_promp['controller_kwargs']['p_gains'] = 1 kwargs_dict_alr_reacher_promp['controller_kwargs']['d_gains'] = 0.1 kwargs_dict_alr_reacher_promp['name'] = f"alr_envs:{_v}" @@ -509,13 +494,12 @@ for _v in _versions: _name = _v.split("-") _env_id = f'{_name[0]}ProMP-{_name[1]}' kwargs_dict_bp_promp = deepcopy(DEFAULT_MP_ENV_DICT) - kwargs_dict_bp_promp['wrappers'].append(mujoco.beerpong.NewMPWrapper) - kwargs_dict_bp_promp['movement_primitives_kwargs']['action_dim'] = 7 - kwargs_dict_bp_promp['phase_generator_kwargs']['tau'] = 0.8 + kwargs_dict_bp_promp['wrappers'].append(mujoco.beerpong.MPWrapper) kwargs_dict_bp_promp['phase_generator_kwargs']['learn_tau'] = True kwargs_dict_bp_promp['controller_kwargs']['p_gains'] = np.array([1.5, 5, 2.55, 3, 2., 2, 1.25]) kwargs_dict_bp_promp['controller_kwargs']['d_gains'] = np.array([0.02333333, 0.1, 0.0625, 0.08, 0.03, 0.03, 0.0125]) kwargs_dict_bp_promp['basis_generator_kwargs']['num_basis'] = 2 + kwargs_dict_bp_promp['basis_generator_kwargs']['num_basis_zero_start'] = 2 kwargs_dict_bp_promp['name'] = f"alr_envs:{_v}" register( id=_env_id, @@ -530,12 +514,12 @@ for _v in _versions: _name = _v.split("-") _env_id = f'{_name[0]}ProMP-{_name[1]}' kwargs_dict_bp_promp = deepcopy(DEFAULT_MP_ENV_DICT) - kwargs_dict_bp_promp['wrappers'].append(mujoco.beerpong.NewMPWrapper) - kwargs_dict_bp_promp['movement_primitives_kwargs']['action_dim'] = 7 + kwargs_dict_bp_promp['wrappers'].append(mujoco.beerpong.MPWrapper) kwargs_dict_bp_promp['phase_generator_kwargs']['tau'] = 0.62 kwargs_dict_bp_promp['controller_kwargs']['p_gains'] = np.array([1.5, 5, 2.55, 3, 2., 2, 1.25]) kwargs_dict_bp_promp['controller_kwargs']['d_gains'] = np.array([0.02333333, 0.1, 0.0625, 0.08, 0.03, 0.03, 0.0125]) kwargs_dict_bp_promp['basis_generator_kwargs']['num_basis'] = 2 + kwargs_dict_bp_promp['basis_generator_kwargs']['num_basis_zero_start'] = 2 kwargs_dict_bp_promp['name'] = f"alr_envs:{_v}" register( id=_env_id, @@ -555,11 +539,7 @@ for _v in _versions: _name = _v.split("-") _env_id = f'{_name[0]}ProMP-{_name[1]}' kwargs_dict_ant_jump_promp = deepcopy(DEFAULT_MP_ENV_DICT) - kwargs_dict_ant_jump_promp['wrappers'].append(mujoco.ant_jump.NewMPWrapper) - kwargs_dict_ant_jump_promp['movement_primitives_kwargs']['action_dim'] = 8 - kwargs_dict_ant_jump_promp['phase_generator_kwargs']['tau'] = 10 - kwargs_dict_ant_jump_promp['controller_kwargs']['p_gains'] = np.ones(8) - kwargs_dict_ant_jump_promp['controller_kwargs']['d_gains'] = 0.1 * np.ones(8) + kwargs_dict_ant_jump_promp['wrappers'].append(mujoco.ant_jump.MPWrapper) kwargs_dict_ant_jump_promp['name'] = f"alr_envs:{_v}" register( id=_env_id, @@ -576,11 +556,7 @@ for _v in _versions: _name = _v.split("-") _env_id = f'{_name[0]}ProMP-{_name[1]}' kwargs_dict_halfcheetah_jump_promp = deepcopy(DEFAULT_MP_ENV_DICT) - kwargs_dict_halfcheetah_jump_promp['wrappers'].append(mujoco.ant_jump.NewMPWrapper) - kwargs_dict_halfcheetah_jump_promp['movement_primitives_kwargs']['action_dim'] = 6 - kwargs_dict_halfcheetah_jump_promp['phase_generator_kwargs']['tau'] = 5 - kwargs_dict_halfcheetah_jump_promp['controller_kwargs']['p_gains'] = np.ones(6) - kwargs_dict_halfcheetah_jump_promp['controller_kwargs']['d_gains'] = 0.1 * np.ones(6) + kwargs_dict_halfcheetah_jump_promp['wrappers'].append(mujoco.half_cheetah_jump.MPWrapper) kwargs_dict_halfcheetah_jump_promp['name'] = f"alr_envs:{_v}" register( id=_env_id, @@ -595,16 +571,12 @@ for _v in _versions: ## HopperJump _versions = ['ALRHopperJump-v0', 'ALRHopperJumpRndmJointsDesPos-v0', 'ALRHopperJumpRndmJointsDesPosStepBased-v0', 'ALRHopperJumpOnBox-v0', 'ALRHopperThrow-v0', 'ALRHopperThrowInBasket-v0'] - +# TODO: Check if all environments work with the same MPWrapper for _v in _versions: _name = _v.split("-") _env_id = f'{_name[0]}ProMP-{_name[1]}' kwargs_dict_hopper_jump_promp = deepcopy(DEFAULT_MP_ENV_DICT) - kwargs_dict_hopper_jump_promp['wrappers'].append('TODO') # TODO - kwargs_dict_hopper_jump_promp['movement_primitives_kwargs']['action_dim'] = 3 - kwargs_dict_hopper_jump_promp['phase_generator_kwargs']['tau'] = 2 - kwargs_dict_hopper_jump_promp['controller_kwargs']['p_gains'] = np.ones(3) - kwargs_dict_hopper_jump_promp['controller_kwargs']['d_gains'] = 0.1 * np.ones(3) + kwargs_dict_hopper_jump_promp['wrappers'].append(mujoco.hopper_jump.MPWrapper) kwargs_dict_hopper_jump_promp['name'] = f"alr_envs:{_v}" register( id=_env_id, @@ -622,11 +594,7 @@ for _v in _versions: _name = _v.split("-") _env_id = f'{_name[0]}ProMP-{_name[1]}' kwargs_dict_walker2d_jump_promp = deepcopy(DEFAULT_MP_ENV_DICT) - kwargs_dict_walker2d_jump_promp['wrappers'].append('TODO') # TODO - kwargs_dict_walker2d_jump_promp['movement_primitives_kwargs']['action_dim'] = 6 - kwargs_dict_walker2d_jump_promp['phase_generator_kwargs']['tau'] = 2.4 - kwargs_dict_walker2d_jump_promp['controller_kwargs']['p_gains'] = np.ones(6) - kwargs_dict_walker2d_jump_promp['controller_kwargs']['d_gains'] = 0.1 * np.ones(6) + kwargs_dict_walker2d_jump_promp['wrappers'].append(mujoco.walker_2d_jump.MPWrapper) kwargs_dict_walker2d_jump_promp['name'] = f"alr_envs:{_v}" register( id=_env_id, diff --git a/alr_envs/alr/classic_control/viapoint_reacher/__init__.py b/alr_envs/alr/classic_control/viapoint_reacher/__init__.py index 989b5a9..a919c3a 100644 --- a/alr_envs/alr/classic_control/viapoint_reacher/__init__.py +++ b/alr_envs/alr/classic_control/viapoint_reacher/__init__.py @@ -1 +1 @@ -from .mp_wrapper import MPWrapper \ No newline at end of file +from new_mp_wrapper import MPWrapper diff --git a/alr_envs/alr/mujoco/ant_jump/__init__.py b/alr_envs/alr/mujoco/ant_jump/__init__.py index 5d15867..8a04a02 100644 --- a/alr_envs/alr/mujoco/ant_jump/__init__.py +++ b/alr_envs/alr/mujoco/ant_jump/__init__.py @@ -1,2 +1 @@ -from .mp_wrapper import MPWrapper -from .new_mp_wrapper import NewMPWrapper +from .new_mp_wrapper import MPWrapper diff --git a/alr_envs/alr/mujoco/ant_jump/new_mp_wrapper.py b/alr_envs/alr/mujoco/ant_jump/new_mp_wrapper.py index c12aa56..0886065 100644 --- a/alr_envs/alr/mujoco/ant_jump/new_mp_wrapper.py +++ b/alr_envs/alr/mujoco/ant_jump/new_mp_wrapper.py @@ -5,7 +5,7 @@ import numpy as np from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper -class NewMPWrapper(RawInterfaceWrapper): +class MPWrapper(RawInterfaceWrapper): def get_context_mask(self): return np.hstack([ diff --git a/alr_envs/alr/mujoco/beerpong/__init__.py b/alr_envs/alr/mujoco/beerpong/__init__.py index 18e33ce..8a04a02 100644 --- a/alr_envs/alr/mujoco/beerpong/__init__.py +++ b/alr_envs/alr/mujoco/beerpong/__init__.py @@ -1,2 +1 @@ -from .mp_wrapper import MPWrapper -from .new_mp_wrapper import NewMPWrapper \ No newline at end of file +from .new_mp_wrapper import MPWrapper diff --git a/alr_envs/alr/mujoco/beerpong/new_mp_wrapper.py b/alr_envs/alr/mujoco/beerpong/new_mp_wrapper.py index 0df1a7c..2969b82 100644 --- a/alr_envs/alr/mujoco/beerpong/new_mp_wrapper.py +++ b/alr_envs/alr/mujoco/beerpong/new_mp_wrapper.py @@ -5,7 +5,7 @@ import numpy as np from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper -class NewMPWrapper(RawInterfaceWrapper): +class MPWrapper(RawInterfaceWrapper): @property def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: return self.env.sim.data.qpos[0:7].copy() diff --git a/alr_envs/alr/mujoco/half_cheetah_jump/__init__.py b/alr_envs/alr/mujoco/half_cheetah_jump/__init__.py index c5e6d2f..8a04a02 100644 --- a/alr_envs/alr/mujoco/half_cheetah_jump/__init__.py +++ b/alr_envs/alr/mujoco/half_cheetah_jump/__init__.py @@ -1 +1 @@ -from .mp_wrapper import MPWrapper +from .new_mp_wrapper import MPWrapper diff --git a/alr_envs/alr/mujoco/hopper_jump/__init__.py b/alr_envs/alr/mujoco/hopper_jump/__init__.py index fbffe48..8a04a02 100644 --- a/alr_envs/alr/mujoco/hopper_jump/__init__.py +++ b/alr_envs/alr/mujoco/hopper_jump/__init__.py @@ -1,2 +1 @@ -from .mp_wrapper import MPWrapper, HighCtxtMPWrapper -from .new_mp_wrapper import NewMPWrapper, NewHighCtxtMPWrapper +from .new_mp_wrapper import MPWrapper diff --git a/alr_envs/alr/mujoco/hopper_jump/new_mp_wrapper.py b/alr_envs/alr/mujoco/hopper_jump/new_mp_wrapper.py index ccd8f76..b919b22 100644 --- a/alr_envs/alr/mujoco/hopper_jump/new_mp_wrapper.py +++ b/alr_envs/alr/mujoco/hopper_jump/new_mp_wrapper.py @@ -3,7 +3,7 @@ from typing import Union, Tuple import numpy as np -class NewMPWrapper(BlackBoxWrapper): +class MPWrapper(BlackBoxWrapper): @property def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: return self.env.sim.data.qpos[3:6].copy() @@ -30,7 +30,7 @@ class NewMPWrapper(BlackBoxWrapper): ]) -class NewHighCtxtMPWrapper(NewMPWrapper): +class NewHighCtxtMPWrapper(MPWrapper): def get_context_mask(self): return np.hstack([ [False] * (2 + int(not self.env.exclude_current_positions_from_observation)), # position diff --git a/alr_envs/alr/mujoco/reacher/__init__.py b/alr_envs/alr/mujoco/reacher/__init__.py index c1a25d3..8a04a02 100644 --- a/alr_envs/alr/mujoco/reacher/__init__.py +++ b/alr_envs/alr/mujoco/reacher/__init__.py @@ -1,2 +1 @@ -from .mp_wrapper import MPWrapper -from .new_mp_wrapper import MPWrapper as NewMPWrapper \ No newline at end of file +from .new_mp_wrapper import MPWrapper diff --git a/alr_envs/alr/mujoco/walker_2d_jump/__init__.py b/alr_envs/alr/mujoco/walker_2d_jump/__init__.py index 989b5a9..8a04a02 100644 --- a/alr_envs/alr/mujoco/walker_2d_jump/__init__.py +++ b/alr_envs/alr/mujoco/walker_2d_jump/__init__.py @@ -1 +1 @@ -from .mp_wrapper import MPWrapper \ No newline at end of file +from .new_mp_wrapper import MPWrapper