Ported dmc envs to mp-config
This commit is contained in:
parent
d48836368c
commit
e743663018
@ -1,247 +1,61 @@
|
||||
from copy import deepcopy
|
||||
|
||||
from gymnasium.wrappers import FlattenObservation
|
||||
from gymnasium.envs.registration import register
|
||||
|
||||
from ..envs.registry import register
|
||||
|
||||
from . import manipulation, suite
|
||||
|
||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}
|
||||
|
||||
from gymnasium.envs.registration import register
|
||||
|
||||
DEFAULT_BB_DICT_ProMP = {
|
||||
"name": 'EnvName',
|
||||
"wrappers": [FlattenObservation],
|
||||
"trajectory_generator_kwargs": {
|
||||
'trajectory_generator_type': 'promp'
|
||||
},
|
||||
"phase_generator_kwargs": {
|
||||
'phase_generator_type': 'linear'
|
||||
},
|
||||
"controller_kwargs": {
|
||||
'controller_type': 'motor',
|
||||
"p_gains": 50.,
|
||||
"d_gains": 1.,
|
||||
},
|
||||
"basis_generator_kwargs": {
|
||||
'basis_generator_type': 'zero_rbf',
|
||||
'num_basis': 5,
|
||||
'num_basis_zero_start': 1
|
||||
}
|
||||
}
|
||||
|
||||
DEFAULT_BB_DICT_DMP = {
|
||||
"name": 'EnvName',
|
||||
"wrappers": [FlattenObservation],
|
||||
"trajectory_generator_kwargs": {
|
||||
'trajectory_generator_type': 'dmp'
|
||||
},
|
||||
"phase_generator_kwargs": {
|
||||
'phase_generator_type': 'exp'
|
||||
},
|
||||
"controller_kwargs": {
|
||||
'controller_type': 'motor',
|
||||
"p_gains": 50.,
|
||||
"d_gains": 1.,
|
||||
},
|
||||
"basis_generator_kwargs": {
|
||||
'basis_generator_type': 'rbf',
|
||||
'num_basis': 5
|
||||
}
|
||||
}
|
||||
|
||||
# DeepMind Control Suite (DMC)
|
||||
kwargs_dict_bic_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||
kwargs_dict_bic_dmp['name'] = f"dm_control/ball_in_cup-catch-v0"
|
||||
kwargs_dict_bic_dmp['wrappers'].append(suite.ball_in_cup.MPWrapper)
|
||||
# bandwidth_factor=2
|
||||
kwargs_dict_bic_dmp['phase_generator_kwargs']['alpha_phase'] = 2
|
||||
kwargs_dict_bic_dmp['trajectory_generator_kwargs']['weight_scale'] = 10 # TODO: weight scale 1, but goal scale 0.1
|
||||
register(
|
||||
id=f'dmc_ball_in_cup-catch_dmp-v0',
|
||||
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_bic_dmp
|
||||
id=f"dm_control/ball_in_cup-catch-v0",
|
||||
register_step_based=False,
|
||||
mp_wrapper=suite.ball_in_cup.MPWrapper,
|
||||
add_mp_types=['DMP', 'ProMP'],
|
||||
)
|
||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_ball_in_cup-catch_dmp-v0")
|
||||
|
||||
kwargs_dict_bic_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||
kwargs_dict_bic_promp['name'] = f"dm_control/ball_in_cup-catch-v0"
|
||||
kwargs_dict_bic_promp['wrappers'].append(suite.ball_in_cup.MPWrapper)
|
||||
register(
|
||||
id=f'dmc_ball_in_cup-catch_promp-v0',
|
||||
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_bic_promp
|
||||
id=f"dm_control/reacher-easy-v0",
|
||||
register_step_based=False,
|
||||
mp_wrapper=suite.reacher.MPWrapper,
|
||||
add_mp_types=['DMP', 'ProMP'],
|
||||
)
|
||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_ball_in_cup-catch_promp-v0")
|
||||
|
||||
kwargs_dict_reacher_easy_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||
kwargs_dict_reacher_easy_dmp['name'] = f"dm_control/reacher-easy-v0"
|
||||
kwargs_dict_reacher_easy_dmp['wrappers'].append(suite.reacher.MPWrapper)
|
||||
# bandwidth_factor=2
|
||||
kwargs_dict_reacher_easy_dmp['phase_generator_kwargs']['alpha_phase'] = 2
|
||||
# TODO: weight scale 50, but goal scale 0.1
|
||||
kwargs_dict_reacher_easy_dmp['trajectory_generator_kwargs']['weight_scale'] = 500
|
||||
register(
|
||||
id=f'dmc_reacher-easy_dmp-v0',
|
||||
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_bic_dmp
|
||||
id=f"dm_control/reacher-hard-v0",
|
||||
register_step_based=False,
|
||||
mp_wrapper=suite.reacher.MPWrapper,
|
||||
add_mp_types=['DMP', 'ProMP'],
|
||||
)
|
||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-easy_dmp-v0")
|
||||
|
||||
kwargs_dict_reacher_easy_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||
kwargs_dict_reacher_easy_promp['name'] = f"dm_control/reacher-easy-v0"
|
||||
kwargs_dict_reacher_easy_promp['wrappers'].append(suite.reacher.MPWrapper)
|
||||
kwargs_dict_reacher_easy_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2
|
||||
register(
|
||||
id=f'dmc_reacher-easy_promp-v0',
|
||||
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_reacher_easy_promp
|
||||
)
|
||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_reacher-easy_promp-v0")
|
||||
|
||||
kwargs_dict_reacher_hard_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||
kwargs_dict_reacher_hard_dmp['name'] = f"dm_control/reacher-hard-v0"
|
||||
kwargs_dict_reacher_hard_dmp['wrappers'].append(suite.reacher.MPWrapper)
|
||||
# bandwidth_factor = 2
|
||||
kwargs_dict_reacher_hard_dmp['phase_generator_kwargs']['alpha_phase'] = 2
|
||||
# TODO: weight scale 50, but goal scale 0.1
|
||||
kwargs_dict_reacher_hard_dmp['trajectory_generator_kwargs']['weight_scale'] = 500
|
||||
register(
|
||||
id=f'dmc_reacher-hard_dmp-v0',
|
||||
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_reacher_hard_dmp
|
||||
)
|
||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-hard_dmp-v0")
|
||||
|
||||
kwargs_dict_reacher_hard_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||
kwargs_dict_reacher_hard_promp['name'] = f"dm_control/reacher-hard-v0"
|
||||
kwargs_dict_reacher_hard_promp['wrappers'].append(suite.reacher.MPWrapper)
|
||||
kwargs_dict_reacher_hard_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2
|
||||
register(
|
||||
id=f'dmc_reacher-hard_promp-v0',
|
||||
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_reacher_hard_promp
|
||||
)
|
||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_reacher-hard_promp-v0")
|
||||
|
||||
_dmc_cartpole_tasks = ["balance", "balance_sparse", "swingup", "swingup_sparse"]
|
||||
|
||||
for _task in _dmc_cartpole_tasks:
|
||||
_env_id = f'dmc_cartpole-{_task}_dmp-v0'
|
||||
kwargs_dict_cartpole_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||
kwargs_dict_cartpole_dmp['name'] = f"dm_control/cartpole-{_task}-v0"
|
||||
kwargs_dict_cartpole_dmp['wrappers'].append(suite.cartpole.MPWrapper)
|
||||
# bandwidth_factor = 2
|
||||
kwargs_dict_cartpole_dmp['phase_generator_kwargs']['alpha_phase'] = 2
|
||||
# TODO: weight scale 50, but goal scale 0.1
|
||||
kwargs_dict_cartpole_dmp['trajectory_generator_kwargs']['weight_scale'] = 500
|
||||
kwargs_dict_cartpole_dmp['controller_kwargs']['p_gains'] = 10
|
||||
kwargs_dict_cartpole_dmp['controller_kwargs']['d_gains'] = 10
|
||||
register(
|
||||
id=_env_id,
|
||||
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_cartpole_dmp
|
||||
id=f'dmc_cartpole-{_task}_dmp-v0',
|
||||
register_step_based=False,
|
||||
mp_wrapper=suite.cartpole.MPWrapper,
|
||||
add_mp_types=['DMP', 'ProMP'],
|
||||
)
|
||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
|
||||
|
||||
_env_id = f'dmc_cartpole-{_task}_promp-v0'
|
||||
kwargs_dict_cartpole_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||
kwargs_dict_cartpole_promp['name'] = f"dm_control/cartpole-{_task}-v0"
|
||||
kwargs_dict_cartpole_promp['wrappers'].append(suite.cartpole.MPWrapper)
|
||||
kwargs_dict_cartpole_promp['controller_kwargs']['p_gains'] = 10
|
||||
kwargs_dict_cartpole_promp['controller_kwargs']['d_gains'] = 10
|
||||
kwargs_dict_cartpole_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2
|
||||
register(
|
||||
id=_env_id,
|
||||
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_cartpole_promp
|
||||
)
|
||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
|
||||
kwargs_dict_cartpole2poles_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||
kwargs_dict_cartpole2poles_dmp['name'] = f"dm_control/cartpole-two_poles-v0"
|
||||
kwargs_dict_cartpole2poles_dmp['wrappers'].append(suite.cartpole.TwoPolesMPWrapper)
|
||||
# bandwidth_factor = 2
|
||||
kwargs_dict_cartpole2poles_dmp['phase_generator_kwargs']['alpha_phase'] = 2
|
||||
# TODO: weight scale 50, but goal scale 0.1
|
||||
kwargs_dict_cartpole2poles_dmp['trajectory_generator_kwargs']['weight_scale'] = 500
|
||||
kwargs_dict_cartpole2poles_dmp['controller_kwargs']['p_gains'] = 10
|
||||
kwargs_dict_cartpole2poles_dmp['controller_kwargs']['d_gains'] = 10
|
||||
_env_id = f'dmc_cartpole-two_poles_dmp-v0'
|
||||
register(
|
||||
id=_env_id,
|
||||
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_cartpole2poles_dmp
|
||||
id=f"dm_control/cartpole-two_poles-v0",
|
||||
register_step_based=False,
|
||||
mp_wrapper=suite.cartpole.TwoPolesMPWrapper,
|
||||
add_mp_types=['DMP', 'ProMP'],
|
||||
)
|
||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
|
||||
|
||||
kwargs_dict_cartpole2poles_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||
kwargs_dict_cartpole2poles_promp['name'] = f"dm_control/cartpole-two_poles-v0"
|
||||
kwargs_dict_cartpole2poles_promp['wrappers'].append(suite.cartpole.TwoPolesMPWrapper)
|
||||
kwargs_dict_cartpole2poles_promp['controller_kwargs']['p_gains'] = 10
|
||||
kwargs_dict_cartpole2poles_promp['controller_kwargs']['d_gains'] = 10
|
||||
kwargs_dict_cartpole2poles_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2
|
||||
_env_id = f'dmc_cartpole-two_poles_promp-v0'
|
||||
register(
|
||||
id=_env_id,
|
||||
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_cartpole2poles_promp
|
||||
id=f"dm_control/cartpole-three_poles-v0",
|
||||
register_step_based=False,
|
||||
mp_wrapper=suite.cartpole.ThreePolesMPWrapper,
|
||||
add_mp_types=['DMP', 'ProMP'],
|
||||
)
|
||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
|
||||
kwargs_dict_cartpole3poles_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||
kwargs_dict_cartpole3poles_dmp['name'] = f"dm_control/cartpole-three_poles-v0"
|
||||
kwargs_dict_cartpole3poles_dmp['wrappers'].append(suite.cartpole.ThreePolesMPWrapper)
|
||||
# bandwidth_factor = 2
|
||||
kwargs_dict_cartpole3poles_dmp['phase_generator_kwargs']['alpha_phase'] = 2
|
||||
# TODO: weight scale 50, but goal scale 0.1
|
||||
kwargs_dict_cartpole3poles_dmp['trajectory_generator_kwargs']['weight_scale'] = 500
|
||||
kwargs_dict_cartpole3poles_dmp['controller_kwargs']['p_gains'] = 10
|
||||
kwargs_dict_cartpole3poles_dmp['controller_kwargs']['d_gains'] = 10
|
||||
_env_id = f'dmc_cartpole-three_poles_dmp-v0'
|
||||
register(
|
||||
id=_env_id,
|
||||
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_cartpole3poles_dmp
|
||||
)
|
||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
|
||||
|
||||
kwargs_dict_cartpole3poles_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||
kwargs_dict_cartpole3poles_promp['name'] = f"dm_control/cartpole-three_poles-v0"
|
||||
kwargs_dict_cartpole3poles_promp['wrappers'].append(suite.cartpole.ThreePolesMPWrapper)
|
||||
kwargs_dict_cartpole3poles_promp['controller_kwargs']['p_gains'] = 10
|
||||
kwargs_dict_cartpole3poles_promp['controller_kwargs']['d_gains'] = 10
|
||||
kwargs_dict_cartpole3poles_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2
|
||||
_env_id = f'dmc_cartpole-three_poles_promp-v0'
|
||||
register(
|
||||
id=_env_id,
|
||||
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_cartpole3poles_promp
|
||||
)
|
||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
|
||||
# DeepMind Manipulation
|
||||
kwargs_dict_mani_reach_site_features_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||
kwargs_dict_mani_reach_site_features_dmp['name'] = f"dm_control/reach_site_features-v0"
|
||||
kwargs_dict_mani_reach_site_features_dmp['wrappers'].append(manipulation.reach_site.MPWrapper)
|
||||
kwargs_dict_mani_reach_site_features_dmp['phase_generator_kwargs']['alpha_phase'] = 2
|
||||
# TODO: weight scale 50, but goal scale 0.1
|
||||
kwargs_dict_mani_reach_site_features_dmp['trajectory_generator_kwargs']['weight_scale'] = 500
|
||||
kwargs_dict_mani_reach_site_features_dmp['controller_kwargs']['controller_type'] = 'velocity'
|
||||
register(
|
||||
id=f'dmc_manipulation-reach_site_dmp-v0',
|
||||
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_mani_reach_site_features_dmp
|
||||
id=f"dm_control/reach_site_features-v0",
|
||||
register_step_based=False,
|
||||
mp_wrapper=manipulation.reach_site.MPWrapper,
|
||||
add_mp_types=['DMP', 'ProMP'],
|
||||
)
|
||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_manipulation-reach_site_dmp-v0")
|
||||
|
||||
kwargs_dict_mani_reach_site_features_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||
kwargs_dict_mani_reach_site_features_promp['name'] = f"dm_control/reach_site_features-v0"
|
||||
kwargs_dict_mani_reach_site_features_promp['wrappers'].append(manipulation.reach_site.MPWrapper)
|
||||
kwargs_dict_mani_reach_site_features_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2
|
||||
kwargs_dict_mani_reach_site_features_promp['controller_kwargs']['controller_type'] = 'velocity'
|
||||
register(
|
||||
id=f'dmc_manipulation-reach_site_promp-v0',
|
||||
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_mani_reach_site_features_promp
|
||||
)
|
||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_manipulation-reach_site_promp-v0")
|
||||
|
@ -6,6 +6,30 @@ from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
||||
|
||||
|
||||
class MPWrapper(RawInterfaceWrapper):
|
||||
mp_config = {
|
||||
'ProMP': {
|
||||
'controller_kwargs': {
|
||||
'controller_type': 'velocity',
|
||||
'p_gains': 50.0,
|
||||
},
|
||||
'trajectory_generator_kwargs': {
|
||||
'weight_scale': 0.2,
|
||||
},
|
||||
},
|
||||
'DMP': {
|
||||
'controller_kwargs': {
|
||||
'controller_type': 'velocity',
|
||||
'p_gains': 50.0,
|
||||
},
|
||||
'phase_generator': {
|
||||
'alpha_phase': 2,
|
||||
},
|
||||
'trajectory_generator_kwargs': {
|
||||
'weight_scale': 500,
|
||||
},
|
||||
},
|
||||
'ProDMP': {},
|
||||
}
|
||||
|
||||
@property
|
||||
def context_mask(self) -> np.ndarray:
|
||||
|
@ -6,6 +6,25 @@ from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
||||
|
||||
|
||||
class MPWrapper(RawInterfaceWrapper):
|
||||
mp_config = {
|
||||
'ProMP': {
|
||||
'controller_kwargs': {
|
||||
'p_gains': 50.0,
|
||||
},
|
||||
},
|
||||
'DMP': {
|
||||
'controller_kwargs': {
|
||||
'p_gains': 50.0,
|
||||
},
|
||||
'phase_generator': {
|
||||
'alpha_phase': 2,
|
||||
},
|
||||
'trajectory_generator_kwargs': {
|
||||
'weight_scale': 10
|
||||
},
|
||||
},
|
||||
'ProDMP': {},
|
||||
}
|
||||
|
||||
@property
|
||||
def context_mask(self) -> np.ndarray:
|
||||
|
@ -6,6 +6,30 @@ from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
||||
|
||||
|
||||
class MPWrapper(RawInterfaceWrapper):
|
||||
mp_config = {
|
||||
'ProMP': {
|
||||
'controller_kwargs': {
|
||||
'p_gains': 10,
|
||||
'd_gains': 10,
|
||||
},
|
||||
'trajectory_generator_kwargs': {
|
||||
'weight_scale': 0.2,
|
||||
},
|
||||
},
|
||||
'DMP': {
|
||||
'controller_kwargs': {
|
||||
'p_gains': 10,
|
||||
'd_gains': 10,
|
||||
},
|
||||
'phase_generator': {
|
||||
'alpha_phase': 2,
|
||||
},
|
||||
'trajectory_generator_kwargs': {
|
||||
'weight_scale': 500,
|
||||
},
|
||||
},
|
||||
'ProDMP': {},
|
||||
}
|
||||
|
||||
def __init__(self, env, n_poles: int = 1):
|
||||
self.n_poles = n_poles
|
||||
|
@ -6,6 +6,26 @@ from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
||||
|
||||
|
||||
class MPWrapper(RawInterfaceWrapper):
|
||||
mp_config = {
|
||||
'ProMP': {
|
||||
'controller_kwargs': {
|
||||
'p_gains': 50.0,
|
||||
'weight_scale': 0.2,
|
||||
},
|
||||
},
|
||||
'DMP': {
|
||||
'controller_kwargs': {
|
||||
'p_gains': 50.0,
|
||||
},
|
||||
'phase_generator': {
|
||||
'alpha_phase': 2,
|
||||
},
|
||||
'trajectory_generator_kwargs': {
|
||||
'weight_scale': 500,
|
||||
},
|
||||
},
|
||||
'ProDMP': {},
|
||||
}
|
||||
|
||||
@property
|
||||
def context_mask(self) -> np.ndarray:
|
||||
|
Loading…
Reference in New Issue
Block a user