adjust env registries in __init__

This commit is contained in:
Onur 2022-06-30 14:55:34 +02:00
parent 3273f455c5
commit f31d85451f
11 changed files with 29 additions and 65 deletions

View File

@ -25,21 +25,13 @@ ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
DEFAULT_MP_ENV_DICT = {
"name": 'EnvName',
"wrappers": [],
# TODO move scale to traj_gen
"ep_wrapper_kwargs": {
"weight_scale": 1
"traj_gen_kwargs": {
"weight_scale": 1,
'movement_primitives_type': 'promp'
},
# TODO traj_gen_kwargs
# TODO remove action_dim
"movement_primitives_kwargs": {
'movement_primitives_type': 'promp',
'action_dim': 7
},
# TODO remove tau
"phase_generator_kwargs": {
'phase_generator_type': 'linear',
'delay': 0,
'tau': 1.5, # initial value
'learn_tau': False,
'learn_delay': False
},
@ -51,7 +43,7 @@ DEFAULT_MP_ENV_DICT = {
"basis_generator_kwargs": {
'basis_generator_type': 'zero_rbf',
'num_basis': 5,
'num_basis_zero_start': 2 # TODO: Change to 1
'num_basis_zero_start': 1
}
}
@ -370,9 +362,7 @@ for _v in _versions:
_env_id = f'{_name[0]}ProMP-{_name[1]}'
kwargs_dict_simple_reacher_promp = deepcopy(DEFAULT_MP_ENV_DICT)
kwargs_dict_simple_reacher_promp['wrappers'].append('TODO') # TODO
kwargs_dict_simple_reacher_promp['movement_primitives_kwargs']['action_dim'] = 2 if "long" not in _v.lower() else 5
kwargs_dict_simple_reacher_promp['phase_generator_kwargs']['tau'] = 2
kwargs_dict_simple_reacher_promp['wrappers'].append(classic_control.simple_reacher.MPWrapper)
kwargs_dict_simple_reacher_promp['controller_kwargs']['p_gains'] = 0.6
kwargs_dict_simple_reacher_promp['controller_kwargs']['d_gains'] = 0.075
kwargs_dict_simple_reacher_promp['name'] = _env_id
@ -405,7 +395,7 @@ register(
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append("ViaPointReacherDMP-v0")
kwargs_dict_via_point_reacher_promp = deepcopy(DEFAULT_MP_ENV_DICT)
kwargs_dict_via_point_reacher_promp['wrappers'].append('TODO') # TODO
kwargs_dict_via_point_reacher_promp['wrappers'].append(classic_control.viapoint_reacher.MPWrapper)
kwargs_dict_via_point_reacher_promp['controller_kwargs']['controller_type'] = 'velocity'
kwargs_dict_via_point_reacher_promp['name'] = "ViaPointReacherProMP-v0"
register(
@ -444,12 +434,9 @@ for _v in _versions:
_env_id = f'{_name[0]}ProMP-{_name[1]}'
kwargs_dict_hole_reacher_promp = deepcopy(DEFAULT_MP_ENV_DICT)
kwargs_dict_hole_reacher_promp['wrappers'].append('TODO') # TODO
kwargs_dict_hole_reacher_promp['ep_wrapper_kwargs']['weight_scale'] = 2
# kwargs_dict_hole_reacher_promp['movement_primitives_kwargs']['action_dim'] = 5
# kwargs_dict_hole_reacher_promp['phase_generator_kwargs']['tau'] = 2
kwargs_dict_hole_reacher_promp['wrappers'].append(classic_control.hole_reacher.MPWrapper)
kwargs_dict_hole_reacher_promp['traj_gen_kwargs']['weight_scale'] = 2
kwargs_dict_hole_reacher_promp['controller_kwargs']['controller_type'] = 'velocity'
# kwargs_dict_hole_reacher_promp['basis_generator_kwargs']['num_basis'] = 5
kwargs_dict_hole_reacher_promp['name'] = f"alr_envs:{_v}"
register(
id=_env_id,
@ -489,9 +476,7 @@ for _v in _versions:
_env_id = f'{_name[0]}ProMP-{_name[1]}'
kwargs_dict_alr_reacher_promp = deepcopy(DEFAULT_MP_ENV_DICT)
kwargs_dict_alr_reacher_promp['wrappers'].append('TODO') # TODO
kwargs_dict_alr_reacher_promp['movement_primitives_kwargs']['action_dim'] = 5 if "long" not in _v.lower() else 7
kwargs_dict_alr_reacher_promp['phase_generator_kwargs']['tau'] = 4
kwargs_dict_alr_reacher_promp['wrappers'].append(mujoco.reacher.MPWrapper)
kwargs_dict_alr_reacher_promp['controller_kwargs']['p_gains'] = 1
kwargs_dict_alr_reacher_promp['controller_kwargs']['d_gains'] = 0.1
kwargs_dict_alr_reacher_promp['name'] = f"alr_envs:{_v}"
@ -509,13 +494,12 @@ for _v in _versions:
_name = _v.split("-")
_env_id = f'{_name[0]}ProMP-{_name[1]}'
kwargs_dict_bp_promp = deepcopy(DEFAULT_MP_ENV_DICT)
kwargs_dict_bp_promp['wrappers'].append(mujoco.beerpong.NewMPWrapper)
kwargs_dict_bp_promp['movement_primitives_kwargs']['action_dim'] = 7
kwargs_dict_bp_promp['phase_generator_kwargs']['tau'] = 0.8
kwargs_dict_bp_promp['wrappers'].append(mujoco.beerpong.MPWrapper)
kwargs_dict_bp_promp['phase_generator_kwargs']['learn_tau'] = True
kwargs_dict_bp_promp['controller_kwargs']['p_gains'] = np.array([1.5, 5, 2.55, 3, 2., 2, 1.25])
kwargs_dict_bp_promp['controller_kwargs']['d_gains'] = np.array([0.02333333, 0.1, 0.0625, 0.08, 0.03, 0.03, 0.0125])
kwargs_dict_bp_promp['basis_generator_kwargs']['num_basis'] = 2
kwargs_dict_bp_promp['basis_generator_kwargs']['num_basis_zero_start'] = 2
kwargs_dict_bp_promp['name'] = f"alr_envs:{_v}"
register(
id=_env_id,
@ -530,12 +514,12 @@ for _v in _versions:
_name = _v.split("-")
_env_id = f'{_name[0]}ProMP-{_name[1]}'
kwargs_dict_bp_promp = deepcopy(DEFAULT_MP_ENV_DICT)
kwargs_dict_bp_promp['wrappers'].append(mujoco.beerpong.NewMPWrapper)
kwargs_dict_bp_promp['movement_primitives_kwargs']['action_dim'] = 7
kwargs_dict_bp_promp['wrappers'].append(mujoco.beerpong.MPWrapper)
kwargs_dict_bp_promp['phase_generator_kwargs']['tau'] = 0.62
kwargs_dict_bp_promp['controller_kwargs']['p_gains'] = np.array([1.5, 5, 2.55, 3, 2., 2, 1.25])
kwargs_dict_bp_promp['controller_kwargs']['d_gains'] = np.array([0.02333333, 0.1, 0.0625, 0.08, 0.03, 0.03, 0.0125])
kwargs_dict_bp_promp['basis_generator_kwargs']['num_basis'] = 2
kwargs_dict_bp_promp['basis_generator_kwargs']['num_basis_zero_start'] = 2
kwargs_dict_bp_promp['name'] = f"alr_envs:{_v}"
register(
id=_env_id,
@ -555,11 +539,7 @@ for _v in _versions:
_name = _v.split("-")
_env_id = f'{_name[0]}ProMP-{_name[1]}'
kwargs_dict_ant_jump_promp = deepcopy(DEFAULT_MP_ENV_DICT)
kwargs_dict_ant_jump_promp['wrappers'].append(mujoco.ant_jump.NewMPWrapper)
kwargs_dict_ant_jump_promp['movement_primitives_kwargs']['action_dim'] = 8
kwargs_dict_ant_jump_promp['phase_generator_kwargs']['tau'] = 10
kwargs_dict_ant_jump_promp['controller_kwargs']['p_gains'] = np.ones(8)
kwargs_dict_ant_jump_promp['controller_kwargs']['d_gains'] = 0.1 * np.ones(8)
kwargs_dict_ant_jump_promp['wrappers'].append(mujoco.ant_jump.MPWrapper)
kwargs_dict_ant_jump_promp['name'] = f"alr_envs:{_v}"
register(
id=_env_id,
@ -576,11 +556,7 @@ for _v in _versions:
_name = _v.split("-")
_env_id = f'{_name[0]}ProMP-{_name[1]}'
kwargs_dict_halfcheetah_jump_promp = deepcopy(DEFAULT_MP_ENV_DICT)
kwargs_dict_halfcheetah_jump_promp['wrappers'].append(mujoco.ant_jump.NewMPWrapper)
kwargs_dict_halfcheetah_jump_promp['movement_primitives_kwargs']['action_dim'] = 6
kwargs_dict_halfcheetah_jump_promp['phase_generator_kwargs']['tau'] = 5
kwargs_dict_halfcheetah_jump_promp['controller_kwargs']['p_gains'] = np.ones(6)
kwargs_dict_halfcheetah_jump_promp['controller_kwargs']['d_gains'] = 0.1 * np.ones(6)
kwargs_dict_halfcheetah_jump_promp['wrappers'].append(mujoco.half_cheetah_jump.MPWrapper)
kwargs_dict_halfcheetah_jump_promp['name'] = f"alr_envs:{_v}"
register(
id=_env_id,
@ -595,16 +571,12 @@ for _v in _versions:
## HopperJump
_versions = ['ALRHopperJump-v0', 'ALRHopperJumpRndmJointsDesPos-v0', 'ALRHopperJumpRndmJointsDesPosStepBased-v0',
'ALRHopperJumpOnBox-v0', 'ALRHopperThrow-v0', 'ALRHopperThrowInBasket-v0']
# TODO: Check if all environments work with the same MPWrapper
for _v in _versions:
_name = _v.split("-")
_env_id = f'{_name[0]}ProMP-{_name[1]}'
kwargs_dict_hopper_jump_promp = deepcopy(DEFAULT_MP_ENV_DICT)
kwargs_dict_hopper_jump_promp['wrappers'].append('TODO') # TODO
kwargs_dict_hopper_jump_promp['movement_primitives_kwargs']['action_dim'] = 3
kwargs_dict_hopper_jump_promp['phase_generator_kwargs']['tau'] = 2
kwargs_dict_hopper_jump_promp['controller_kwargs']['p_gains'] = np.ones(3)
kwargs_dict_hopper_jump_promp['controller_kwargs']['d_gains'] = 0.1 * np.ones(3)
kwargs_dict_hopper_jump_promp['wrappers'].append(mujoco.hopper_jump.MPWrapper)
kwargs_dict_hopper_jump_promp['name'] = f"alr_envs:{_v}"
register(
id=_env_id,
@ -622,11 +594,7 @@ for _v in _versions:
_name = _v.split("-")
_env_id = f'{_name[0]}ProMP-{_name[1]}'
kwargs_dict_walker2d_jump_promp = deepcopy(DEFAULT_MP_ENV_DICT)
kwargs_dict_walker2d_jump_promp['wrappers'].append('TODO') # TODO
kwargs_dict_walker2d_jump_promp['movement_primitives_kwargs']['action_dim'] = 6
kwargs_dict_walker2d_jump_promp['phase_generator_kwargs']['tau'] = 2.4
kwargs_dict_walker2d_jump_promp['controller_kwargs']['p_gains'] = np.ones(6)
kwargs_dict_walker2d_jump_promp['controller_kwargs']['d_gains'] = 0.1 * np.ones(6)
kwargs_dict_walker2d_jump_promp['wrappers'].append(mujoco.walker_2d_jump.MPWrapper)
kwargs_dict_walker2d_jump_promp['name'] = f"alr_envs:{_v}"
register(
id=_env_id,

View File

@ -1 +1 @@
from .mp_wrapper import MPWrapper
# Explicit relative import, matching the sibling packages: a bare
# `from new_mp_wrapper import ...` is an absolute import on Python 3 and
# would only resolve if a top-level module of that name were on sys.path.
from .new_mp_wrapper import MPWrapper

View File

@ -1,2 +1 @@
from .mp_wrapper import MPWrapper
from .new_mp_wrapper import NewMPWrapper
from .new_mp_wrapper import MPWrapper

View File

@ -5,7 +5,7 @@ import numpy as np
from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
class NewMPWrapper(RawInterfaceWrapper):
class MPWrapper(RawInterfaceWrapper):
def get_context_mask(self):
return np.hstack([

View File

@ -1,2 +1 @@
from .mp_wrapper import MPWrapper
from .new_mp_wrapper import NewMPWrapper
from .new_mp_wrapper import MPWrapper

View File

@ -5,7 +5,7 @@ import numpy as np
from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
class NewMPWrapper(RawInterfaceWrapper):
class MPWrapper(RawInterfaceWrapper):
@property
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
return self.env.sim.data.qpos[0:7].copy()

View File

@ -1 +1 @@
from .mp_wrapper import MPWrapper
from .new_mp_wrapper import MPWrapper

View File

@ -1,2 +1 @@
from .mp_wrapper import MPWrapper, HighCtxtMPWrapper
from .new_mp_wrapper import NewMPWrapper, NewHighCtxtMPWrapper
from .new_mp_wrapper import MPWrapper

View File

@ -3,7 +3,7 @@ from typing import Union, Tuple
import numpy as np
class NewMPWrapper(BlackBoxWrapper):
class MPWrapper(BlackBoxWrapper):
@property
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
return self.env.sim.data.qpos[3:6].copy()
@ -30,7 +30,7 @@ class NewMPWrapper(BlackBoxWrapper):
])
class NewHighCtxtMPWrapper(NewMPWrapper):
class NewHighCtxtMPWrapper(MPWrapper):
def get_context_mask(self):
return np.hstack([
[False] * (2 + int(not self.env.exclude_current_positions_from_observation)), # position

View File

@ -1,2 +1 @@
from .mp_wrapper import MPWrapper
from .new_mp_wrapper import MPWrapper as NewMPWrapper
from .new_mp_wrapper import MPWrapper

View File

@ -1 +1 @@
from .mp_wrapper import MPWrapper
from .new_mp_wrapper import MPWrapper