adjust env registries in __init__

This commit is contained in:
Onur 2022-06-30 14:55:34 +02:00
parent 3273f455c5
commit f31d85451f
11 changed files with 29 additions and 65 deletions

View File

@ -25,21 +25,13 @@ ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
DEFAULT_MP_ENV_DICT = { DEFAULT_MP_ENV_DICT = {
"name": 'EnvName', "name": 'EnvName',
"wrappers": [], "wrappers": [],
# TODO move scale to traj_gen "traj_gen_kwargs": {
"ep_wrapper_kwargs": { "weight_scale": 1,
"weight_scale": 1 'movement_primitives_type': 'promp'
}, },
# TODO traj_gen_kwargs
# TODO remove action_dim
"movement_primitives_kwargs": {
'movement_primitives_type': 'promp',
'action_dim': 7
},
# TODO remove tau
"phase_generator_kwargs": { "phase_generator_kwargs": {
'phase_generator_type': 'linear', 'phase_generator_type': 'linear',
'delay': 0, 'delay': 0,
'tau': 1.5, # initial value
'learn_tau': False, 'learn_tau': False,
'learn_delay': False 'learn_delay': False
}, },
@ -51,7 +43,7 @@ DEFAULT_MP_ENV_DICT = {
"basis_generator_kwargs": { "basis_generator_kwargs": {
'basis_generator_type': 'zero_rbf', 'basis_generator_type': 'zero_rbf',
'num_basis': 5, 'num_basis': 5,
'num_basis_zero_start': 2 # TODO: Change to 1 'num_basis_zero_start': 1
} }
} }
@ -370,9 +362,7 @@ for _v in _versions:
_env_id = f'{_name[0]}ProMP-{_name[1]}' _env_id = f'{_name[0]}ProMP-{_name[1]}'
kwargs_dict_simple_reacher_promp = deepcopy(DEFAULT_MP_ENV_DICT) kwargs_dict_simple_reacher_promp = deepcopy(DEFAULT_MP_ENV_DICT)
kwargs_dict_simple_reacher_promp['wrappers'].append('TODO') # TODO kwargs_dict_simple_reacher_promp['wrappers'].append(classic_control.simple_reacher.MPWrapper)
kwargs_dict_simple_reacher_promp['movement_primitives_kwargs']['action_dim'] = 2 if "long" not in _v.lower() else 5
kwargs_dict_simple_reacher_promp['phase_generator_kwargs']['tau'] = 2
kwargs_dict_simple_reacher_promp['controller_kwargs']['p_gains'] = 0.6 kwargs_dict_simple_reacher_promp['controller_kwargs']['p_gains'] = 0.6
kwargs_dict_simple_reacher_promp['controller_kwargs']['d_gains'] = 0.075 kwargs_dict_simple_reacher_promp['controller_kwargs']['d_gains'] = 0.075
kwargs_dict_simple_reacher_promp['name'] = _env_id kwargs_dict_simple_reacher_promp['name'] = _env_id
@ -405,7 +395,7 @@ register(
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append("ViaPointReacherDMP-v0") ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append("ViaPointReacherDMP-v0")
kwargs_dict_via_point_reacher_promp = deepcopy(DEFAULT_MP_ENV_DICT) kwargs_dict_via_point_reacher_promp = deepcopy(DEFAULT_MP_ENV_DICT)
kwargs_dict_via_point_reacher_promp['wrappers'].append('TODO') # TODO kwargs_dict_via_point_reacher_promp['wrappers'].append(classic_control.viapoint_reacher.MPWrapper)
kwargs_dict_via_point_reacher_promp['controller_kwargs']['controller_type'] = 'velocity' kwargs_dict_via_point_reacher_promp['controller_kwargs']['controller_type'] = 'velocity'
kwargs_dict_via_point_reacher_promp['name'] = "ViaPointReacherProMP-v0" kwargs_dict_via_point_reacher_promp['name'] = "ViaPointReacherProMP-v0"
register( register(
@ -444,12 +434,9 @@ for _v in _versions:
_env_id = f'{_name[0]}ProMP-{_name[1]}' _env_id = f'{_name[0]}ProMP-{_name[1]}'
kwargs_dict_hole_reacher_promp = deepcopy(DEFAULT_MP_ENV_DICT) kwargs_dict_hole_reacher_promp = deepcopy(DEFAULT_MP_ENV_DICT)
kwargs_dict_hole_reacher_promp['wrappers'].append('TODO') # TODO kwargs_dict_hole_reacher_promp['wrappers'].append(classic_control.hole_reacher.MPWrapper)
kwargs_dict_hole_reacher_promp['ep_wrapper_kwargs']['weight_scale'] = 2 kwargs_dict_hole_reacher_promp['traj_gen_kwargs']['weight_scale'] = 2
# kwargs_dict_hole_reacher_promp['movement_primitives_kwargs']['action_dim'] = 5
# kwargs_dict_hole_reacher_promp['phase_generator_kwargs']['tau'] = 2
kwargs_dict_hole_reacher_promp['controller_kwargs']['controller_type'] = 'velocity' kwargs_dict_hole_reacher_promp['controller_kwargs']['controller_type'] = 'velocity'
# kwargs_dict_hole_reacher_promp['basis_generator_kwargs']['num_basis'] = 5
kwargs_dict_hole_reacher_promp['name'] = f"alr_envs:{_v}" kwargs_dict_hole_reacher_promp['name'] = f"alr_envs:{_v}"
register( register(
id=_env_id, id=_env_id,
@ -489,9 +476,7 @@ for _v in _versions:
_env_id = f'{_name[0]}ProMP-{_name[1]}' _env_id = f'{_name[0]}ProMP-{_name[1]}'
kwargs_dict_alr_reacher_promp = deepcopy(DEFAULT_MP_ENV_DICT) kwargs_dict_alr_reacher_promp = deepcopy(DEFAULT_MP_ENV_DICT)
kwargs_dict_alr_reacher_promp['wrappers'].append('TODO') # TODO kwargs_dict_alr_reacher_promp['wrappers'].append(mujoco.reacher.MPWrapper)
kwargs_dict_alr_reacher_promp['movement_primitives_kwargs']['action_dim'] = 5 if "long" not in _v.lower() else 7
kwargs_dict_alr_reacher_promp['phase_generator_kwargs']['tau'] = 4
kwargs_dict_alr_reacher_promp['controller_kwargs']['p_gains'] = 1 kwargs_dict_alr_reacher_promp['controller_kwargs']['p_gains'] = 1
kwargs_dict_alr_reacher_promp['controller_kwargs']['d_gains'] = 0.1 kwargs_dict_alr_reacher_promp['controller_kwargs']['d_gains'] = 0.1
kwargs_dict_alr_reacher_promp['name'] = f"alr_envs:{_v}" kwargs_dict_alr_reacher_promp['name'] = f"alr_envs:{_v}"
@ -509,13 +494,12 @@ for _v in _versions:
_name = _v.split("-") _name = _v.split("-")
_env_id = f'{_name[0]}ProMP-{_name[1]}' _env_id = f'{_name[0]}ProMP-{_name[1]}'
kwargs_dict_bp_promp = deepcopy(DEFAULT_MP_ENV_DICT) kwargs_dict_bp_promp = deepcopy(DEFAULT_MP_ENV_DICT)
kwargs_dict_bp_promp['wrappers'].append(mujoco.beerpong.NewMPWrapper) kwargs_dict_bp_promp['wrappers'].append(mujoco.beerpong.MPWrapper)
kwargs_dict_bp_promp['movement_primitives_kwargs']['action_dim'] = 7
kwargs_dict_bp_promp['phase_generator_kwargs']['tau'] = 0.8
kwargs_dict_bp_promp['phase_generator_kwargs']['learn_tau'] = True kwargs_dict_bp_promp['phase_generator_kwargs']['learn_tau'] = True
kwargs_dict_bp_promp['controller_kwargs']['p_gains'] = np.array([1.5, 5, 2.55, 3, 2., 2, 1.25]) kwargs_dict_bp_promp['controller_kwargs']['p_gains'] = np.array([1.5, 5, 2.55, 3, 2., 2, 1.25])
kwargs_dict_bp_promp['controller_kwargs']['d_gains'] = np.array([0.02333333, 0.1, 0.0625, 0.08, 0.03, 0.03, 0.0125]) kwargs_dict_bp_promp['controller_kwargs']['d_gains'] = np.array([0.02333333, 0.1, 0.0625, 0.08, 0.03, 0.03, 0.0125])
kwargs_dict_bp_promp['basis_generator_kwargs']['num_basis'] = 2 kwargs_dict_bp_promp['basis_generator_kwargs']['num_basis'] = 2
kwargs_dict_bp_promp['basis_generator_kwargs']['num_basis_zero_start'] = 2
kwargs_dict_bp_promp['name'] = f"alr_envs:{_v}" kwargs_dict_bp_promp['name'] = f"alr_envs:{_v}"
register( register(
id=_env_id, id=_env_id,
@ -530,12 +514,12 @@ for _v in _versions:
_name = _v.split("-") _name = _v.split("-")
_env_id = f'{_name[0]}ProMP-{_name[1]}' _env_id = f'{_name[0]}ProMP-{_name[1]}'
kwargs_dict_bp_promp = deepcopy(DEFAULT_MP_ENV_DICT) kwargs_dict_bp_promp = deepcopy(DEFAULT_MP_ENV_DICT)
kwargs_dict_bp_promp['wrappers'].append(mujoco.beerpong.NewMPWrapper) kwargs_dict_bp_promp['wrappers'].append(mujoco.beerpong.MPWrapper)
kwargs_dict_bp_promp['movement_primitives_kwargs']['action_dim'] = 7
kwargs_dict_bp_promp['phase_generator_kwargs']['tau'] = 0.62 kwargs_dict_bp_promp['phase_generator_kwargs']['tau'] = 0.62
kwargs_dict_bp_promp['controller_kwargs']['p_gains'] = np.array([1.5, 5, 2.55, 3, 2., 2, 1.25]) kwargs_dict_bp_promp['controller_kwargs']['p_gains'] = np.array([1.5, 5, 2.55, 3, 2., 2, 1.25])
kwargs_dict_bp_promp['controller_kwargs']['d_gains'] = np.array([0.02333333, 0.1, 0.0625, 0.08, 0.03, 0.03, 0.0125]) kwargs_dict_bp_promp['controller_kwargs']['d_gains'] = np.array([0.02333333, 0.1, 0.0625, 0.08, 0.03, 0.03, 0.0125])
kwargs_dict_bp_promp['basis_generator_kwargs']['num_basis'] = 2 kwargs_dict_bp_promp['basis_generator_kwargs']['num_basis'] = 2
kwargs_dict_bp_promp['basis_generator_kwargs']['num_basis_zero_start'] = 2
kwargs_dict_bp_promp['name'] = f"alr_envs:{_v}" kwargs_dict_bp_promp['name'] = f"alr_envs:{_v}"
register( register(
id=_env_id, id=_env_id,
@ -555,11 +539,7 @@ for _v in _versions:
_name = _v.split("-") _name = _v.split("-")
_env_id = f'{_name[0]}ProMP-{_name[1]}' _env_id = f'{_name[0]}ProMP-{_name[1]}'
kwargs_dict_ant_jump_promp = deepcopy(DEFAULT_MP_ENV_DICT) kwargs_dict_ant_jump_promp = deepcopy(DEFAULT_MP_ENV_DICT)
kwargs_dict_ant_jump_promp['wrappers'].append(mujoco.ant_jump.NewMPWrapper) kwargs_dict_ant_jump_promp['wrappers'].append(mujoco.ant_jump.MPWrapper)
kwargs_dict_ant_jump_promp['movement_primitives_kwargs']['action_dim'] = 8
kwargs_dict_ant_jump_promp['phase_generator_kwargs']['tau'] = 10
kwargs_dict_ant_jump_promp['controller_kwargs']['p_gains'] = np.ones(8)
kwargs_dict_ant_jump_promp['controller_kwargs']['d_gains'] = 0.1 * np.ones(8)
kwargs_dict_ant_jump_promp['name'] = f"alr_envs:{_v}" kwargs_dict_ant_jump_promp['name'] = f"alr_envs:{_v}"
register( register(
id=_env_id, id=_env_id,
@ -576,11 +556,7 @@ for _v in _versions:
_name = _v.split("-") _name = _v.split("-")
_env_id = f'{_name[0]}ProMP-{_name[1]}' _env_id = f'{_name[0]}ProMP-{_name[1]}'
kwargs_dict_halfcheetah_jump_promp = deepcopy(DEFAULT_MP_ENV_DICT) kwargs_dict_halfcheetah_jump_promp = deepcopy(DEFAULT_MP_ENV_DICT)
kwargs_dict_halfcheetah_jump_promp['wrappers'].append(mujoco.ant_jump.NewMPWrapper) kwargs_dict_halfcheetah_jump_promp['wrappers'].append(mujoco.half_cheetah_jump.MPWrapper)
kwargs_dict_halfcheetah_jump_promp['movement_primitives_kwargs']['action_dim'] = 6
kwargs_dict_halfcheetah_jump_promp['phase_generator_kwargs']['tau'] = 5
kwargs_dict_halfcheetah_jump_promp['controller_kwargs']['p_gains'] = np.ones(6)
kwargs_dict_halfcheetah_jump_promp['controller_kwargs']['d_gains'] = 0.1 * np.ones(6)
kwargs_dict_halfcheetah_jump_promp['name'] = f"alr_envs:{_v}" kwargs_dict_halfcheetah_jump_promp['name'] = f"alr_envs:{_v}"
register( register(
id=_env_id, id=_env_id,
@ -595,16 +571,12 @@ for _v in _versions:
## HopperJump ## HopperJump
_versions = ['ALRHopperJump-v0', 'ALRHopperJumpRndmJointsDesPos-v0', 'ALRHopperJumpRndmJointsDesPosStepBased-v0', _versions = ['ALRHopperJump-v0', 'ALRHopperJumpRndmJointsDesPos-v0', 'ALRHopperJumpRndmJointsDesPosStepBased-v0',
'ALRHopperJumpOnBox-v0', 'ALRHopperThrow-v0', 'ALRHopperThrowInBasket-v0'] 'ALRHopperJumpOnBox-v0', 'ALRHopperThrow-v0', 'ALRHopperThrowInBasket-v0']
# TODO: Check if all environments work with the same MPWrapper
for _v in _versions: for _v in _versions:
_name = _v.split("-") _name = _v.split("-")
_env_id = f'{_name[0]}ProMP-{_name[1]}' _env_id = f'{_name[0]}ProMP-{_name[1]}'
kwargs_dict_hopper_jump_promp = deepcopy(DEFAULT_MP_ENV_DICT) kwargs_dict_hopper_jump_promp = deepcopy(DEFAULT_MP_ENV_DICT)
kwargs_dict_hopper_jump_promp['wrappers'].append('TODO') # TODO kwargs_dict_hopper_jump_promp['wrappers'].append(mujoco.hopper_jump.MPWrapper)
kwargs_dict_hopper_jump_promp['movement_primitives_kwargs']['action_dim'] = 3
kwargs_dict_hopper_jump_promp['phase_generator_kwargs']['tau'] = 2
kwargs_dict_hopper_jump_promp['controller_kwargs']['p_gains'] = np.ones(3)
kwargs_dict_hopper_jump_promp['controller_kwargs']['d_gains'] = 0.1 * np.ones(3)
kwargs_dict_hopper_jump_promp['name'] = f"alr_envs:{_v}" kwargs_dict_hopper_jump_promp['name'] = f"alr_envs:{_v}"
register( register(
id=_env_id, id=_env_id,
@ -622,11 +594,7 @@ for _v in _versions:
_name = _v.split("-") _name = _v.split("-")
_env_id = f'{_name[0]}ProMP-{_name[1]}' _env_id = f'{_name[0]}ProMP-{_name[1]}'
kwargs_dict_walker2d_jump_promp = deepcopy(DEFAULT_MP_ENV_DICT) kwargs_dict_walker2d_jump_promp = deepcopy(DEFAULT_MP_ENV_DICT)
kwargs_dict_walker2d_jump_promp['wrappers'].append('TODO') # TODO kwargs_dict_walker2d_jump_promp['wrappers'].append(mujoco.walker_2d_jump.MPWrapper)
kwargs_dict_walker2d_jump_promp['movement_primitives_kwargs']['action_dim'] = 6
kwargs_dict_walker2d_jump_promp['phase_generator_kwargs']['tau'] = 2.4
kwargs_dict_walker2d_jump_promp['controller_kwargs']['p_gains'] = np.ones(6)
kwargs_dict_walker2d_jump_promp['controller_kwargs']['d_gains'] = 0.1 * np.ones(6)
kwargs_dict_walker2d_jump_promp['name'] = f"alr_envs:{_v}" kwargs_dict_walker2d_jump_promp['name'] = f"alr_envs:{_v}"
register( register(
id=_env_id, id=_env_id,

View File

@ -1 +1 @@
from .mp_wrapper import MPWrapper from new_mp_wrapper import MPWrapper

View File

@ -1,2 +1 @@
from .mp_wrapper import MPWrapper from .new_mp_wrapper import MPWrapper
from .new_mp_wrapper import NewMPWrapper

View File

@ -5,7 +5,7 @@ import numpy as np
from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
class NewMPWrapper(RawInterfaceWrapper): class MPWrapper(RawInterfaceWrapper):
def get_context_mask(self): def get_context_mask(self):
return np.hstack([ return np.hstack([

View File

@ -1,2 +1 @@
from .mp_wrapper import MPWrapper from .new_mp_wrapper import MPWrapper
from .new_mp_wrapper import NewMPWrapper

View File

@ -5,7 +5,7 @@ import numpy as np
from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
class NewMPWrapper(RawInterfaceWrapper): class MPWrapper(RawInterfaceWrapper):
@property @property
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
return self.env.sim.data.qpos[0:7].copy() return self.env.sim.data.qpos[0:7].copy()

View File

@ -1 +1 @@
from .mp_wrapper import MPWrapper from .new_mp_wrapper import MPWrapper

View File

@ -1,2 +1 @@
from .mp_wrapper import MPWrapper, HighCtxtMPWrapper from .new_mp_wrapper import MPWrapper
from .new_mp_wrapper import NewMPWrapper, NewHighCtxtMPWrapper

View File

@ -3,7 +3,7 @@ from typing import Union, Tuple
import numpy as np import numpy as np
class NewMPWrapper(BlackBoxWrapper): class MPWrapper(BlackBoxWrapper):
@property @property
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
return self.env.sim.data.qpos[3:6].copy() return self.env.sim.data.qpos[3:6].copy()
@ -30,7 +30,7 @@ class NewMPWrapper(BlackBoxWrapper):
]) ])
class NewHighCtxtMPWrapper(NewMPWrapper): class NewHighCtxtMPWrapper(MPWrapper):
def get_context_mask(self): def get_context_mask(self):
return np.hstack([ return np.hstack([
[False] * (2 + int(not self.env.exclude_current_positions_from_observation)), # position [False] * (2 + int(not self.env.exclude_current_positions_from_observation)), # position

View File

@ -1,2 +1 @@
from .mp_wrapper import MPWrapper from .new_mp_wrapper import MPWrapper
from .new_mp_wrapper import MPWrapper as NewMPWrapper

View File

@ -1 +1 @@
from .mp_wrapper import MPWrapper from .new_mp_wrapper import MPWrapper