From e7436630180f9eb7d7c83b2d76c944c4f2f8869d Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sun, 30 Jul 2023 17:41:44 +0200 Subject: [PATCH] Ported dmc envs to mp-config --- fancy_gym/dmc/__init__.py | 248 +++--------------- .../dmc/manipulation/reach_site/mp_wrapper.py | 24 ++ fancy_gym/dmc/suite/ball_in_cup/mp_wrapper.py | 19 ++ fancy_gym/dmc/suite/cartpole/mp_wrapper.py | 24 ++ fancy_gym/dmc/suite/reacher/mp_wrapper.py | 20 ++ 5 files changed, 118 insertions(+), 217 deletions(-) diff --git a/fancy_gym/dmc/__init__.py b/fancy_gym/dmc/__init__.py index 5d7466c..28e1a0a 100644 --- a/fancy_gym/dmc/__init__.py +++ b/fancy_gym/dmc/__init__.py @@ -1,247 +1,61 @@ from copy import deepcopy from gymnasium.wrappers import FlattenObservation +from gymnasium.envs.registration import register + +from ..envs.registry import register from . import manipulation, suite -ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []} - -from gymnasium.envs.registration import register - -DEFAULT_BB_DICT_ProMP = { - "name": 'EnvName', - "wrappers": [FlattenObservation], - "trajectory_generator_kwargs": { - 'trajectory_generator_type': 'promp' - }, - "phase_generator_kwargs": { - 'phase_generator_type': 'linear' - }, - "controller_kwargs": { - 'controller_type': 'motor', - "p_gains": 50., - "d_gains": 1., - }, - "basis_generator_kwargs": { - 'basis_generator_type': 'zero_rbf', - 'num_basis': 5, - 'num_basis_zero_start': 1 - } -} - -DEFAULT_BB_DICT_DMP = { - "name": 'EnvName', - "wrappers": [FlattenObservation], - "trajectory_generator_kwargs": { - 'trajectory_generator_type': 'dmp' - }, - "phase_generator_kwargs": { - 'phase_generator_type': 'exp' - }, - "controller_kwargs": { - 'controller_type': 'motor', - "p_gains": 50., - "d_gains": 1., - }, - "basis_generator_kwargs": { - 'basis_generator_type': 'rbf', - 'num_basis': 5 - } -} - # DeepMind Control Suite (DMC) -kwargs_dict_bic_dmp = deepcopy(DEFAULT_BB_DICT_DMP) -kwargs_dict_bic_dmp['name'] = 
f"dm_control/ball_in_cup-catch-v0" -kwargs_dict_bic_dmp['wrappers'].append(suite.ball_in_cup.MPWrapper) -# bandwidth_factor=2 -kwargs_dict_bic_dmp['phase_generator_kwargs']['alpha_phase'] = 2 -kwargs_dict_bic_dmp['trajectory_generator_kwargs']['weight_scale'] = 10 # TODO: weight scale 1, but goal scale 0.1 register( - id=f'dmc_ball_in_cup-catch_dmp-v0', - entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', - kwargs=kwargs_dict_bic_dmp + id=f"dm_control/ball_in_cup-catch-v0", + register_step_based=False, + mp_wrapper=suite.ball_in_cup.MPWrapper, + add_mp_types=['DMP', 'ProMP'], ) -ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_ball_in_cup-catch_dmp-v0") -kwargs_dict_bic_promp = deepcopy(DEFAULT_BB_DICT_DMP) -kwargs_dict_bic_promp['name'] = f"dm_control/ball_in_cup-catch-v0" -kwargs_dict_bic_promp['wrappers'].append(suite.ball_in_cup.MPWrapper) register( - id=f'dmc_ball_in_cup-catch_promp-v0', - entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', - kwargs=kwargs_dict_bic_promp + id=f"dm_control/reacher-easy-v0", + register_step_based=False, + mp_wrapper=suite.reacher.MPWrapper, + add_mp_types=['DMP', 'ProMP'], ) -ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_ball_in_cup-catch_promp-v0") -kwargs_dict_reacher_easy_dmp = deepcopy(DEFAULT_BB_DICT_DMP) -kwargs_dict_reacher_easy_dmp['name'] = f"dm_control/reacher-easy-v0" -kwargs_dict_reacher_easy_dmp['wrappers'].append(suite.reacher.MPWrapper) -# bandwidth_factor=2 -kwargs_dict_reacher_easy_dmp['phase_generator_kwargs']['alpha_phase'] = 2 -# TODO: weight scale 50, but goal scale 0.1 -kwargs_dict_reacher_easy_dmp['trajectory_generator_kwargs']['weight_scale'] = 500 register( - id=f'dmc_reacher-easy_dmp-v0', - entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', - kwargs=kwargs_dict_bic_dmp + id=f"dm_control/reacher-hard-v0", + register_step_based=False, + mp_wrapper=suite.reacher.MPWrapper, + add_mp_types=['DMP', 'ProMP'], ) 
-ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-easy_dmp-v0") - -kwargs_dict_reacher_easy_promp = deepcopy(DEFAULT_BB_DICT_DMP) -kwargs_dict_reacher_easy_promp['name'] = f"dm_control/reacher-easy-v0" -kwargs_dict_reacher_easy_promp['wrappers'].append(suite.reacher.MPWrapper) -kwargs_dict_reacher_easy_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2 -register( - id=f'dmc_reacher-easy_promp-v0', - entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', - kwargs=kwargs_dict_reacher_easy_promp -) -ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_reacher-easy_promp-v0") - -kwargs_dict_reacher_hard_dmp = deepcopy(DEFAULT_BB_DICT_DMP) -kwargs_dict_reacher_hard_dmp['name'] = f"dm_control/reacher-hard-v0" -kwargs_dict_reacher_hard_dmp['wrappers'].append(suite.reacher.MPWrapper) -# bandwidth_factor = 2 -kwargs_dict_reacher_hard_dmp['phase_generator_kwargs']['alpha_phase'] = 2 -# TODO: weight scale 50, but goal scale 0.1 -kwargs_dict_reacher_hard_dmp['trajectory_generator_kwargs']['weight_scale'] = 500 -register( - id=f'dmc_reacher-hard_dmp-v0', - entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', - kwargs=kwargs_dict_reacher_hard_dmp -) -ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-hard_dmp-v0") - -kwargs_dict_reacher_hard_promp = deepcopy(DEFAULT_BB_DICT_DMP) -kwargs_dict_reacher_hard_promp['name'] = f"dm_control/reacher-hard-v0" -kwargs_dict_reacher_hard_promp['wrappers'].append(suite.reacher.MPWrapper) -kwargs_dict_reacher_hard_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2 -register( - id=f'dmc_reacher-hard_promp-v0', - entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', - kwargs=kwargs_dict_reacher_hard_promp -) -ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_reacher-hard_promp-v0") _dmc_cartpole_tasks = ["balance", "balance_sparse", "swingup", "swingup_sparse"] - for _task in _dmc_cartpole_tasks: - _env_id = f'dmc_cartpole-{_task}_dmp-v0' 
- kwargs_dict_cartpole_dmp = deepcopy(DEFAULT_BB_DICT_DMP) - kwargs_dict_cartpole_dmp['name'] = f"dm_control/cartpole-{_task}-v0" - kwargs_dict_cartpole_dmp['wrappers'].append(suite.cartpole.MPWrapper) - # bandwidth_factor = 2 - kwargs_dict_cartpole_dmp['phase_generator_kwargs']['alpha_phase'] = 2 - # TODO: weight scale 50, but goal scale 0.1 - kwargs_dict_cartpole_dmp['trajectory_generator_kwargs']['weight_scale'] = 500 - kwargs_dict_cartpole_dmp['controller_kwargs']['p_gains'] = 10 - kwargs_dict_cartpole_dmp['controller_kwargs']['d_gains'] = 10 register( - id=_env_id, - entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', - kwargs=kwargs_dict_cartpole_dmp + id=f'dm_control/cartpole-{_task}-v0', + register_step_based=False, + mp_wrapper=suite.cartpole.MPWrapper, + add_mp_types=['DMP', 'ProMP'], ) - ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id) - _env_id = f'dmc_cartpole-{_task}_promp-v0' - kwargs_dict_cartpole_promp = deepcopy(DEFAULT_BB_DICT_DMP) - kwargs_dict_cartpole_promp['name'] = f"dm_control/cartpole-{_task}-v0" - kwargs_dict_cartpole_promp['wrappers'].append(suite.cartpole.MPWrapper) - kwargs_dict_cartpole_promp['controller_kwargs']['p_gains'] = 10 - kwargs_dict_cartpole_promp['controller_kwargs']['d_gains'] = 10 - kwargs_dict_cartpole_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2 - register( - id=_env_id, - entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', - kwargs=kwargs_dict_cartpole_promp - ) - ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) - -kwargs_dict_cartpole2poles_dmp = deepcopy(DEFAULT_BB_DICT_DMP) -kwargs_dict_cartpole2poles_dmp['name'] = f"dm_control/cartpole-two_poles-v0" -kwargs_dict_cartpole2poles_dmp['wrappers'].append(suite.cartpole.TwoPolesMPWrapper) -# bandwidth_factor = 2 -kwargs_dict_cartpole2poles_dmp['phase_generator_kwargs']['alpha_phase'] = 2 -# TODO: weight scale 50, but goal scale 0.1 
-kwargs_dict_cartpole2poles_dmp['trajectory_generator_kwargs']['weight_scale'] = 500 -kwargs_dict_cartpole2poles_dmp['controller_kwargs']['p_gains'] = 10 -kwargs_dict_cartpole2poles_dmp['controller_kwargs']['d_gains'] = 10 -_env_id = f'dmc_cartpole-two_poles_dmp-v0' register( - id=_env_id, - entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', - kwargs=kwargs_dict_cartpole2poles_dmp + id=f"dm_control/cartpole-two_poles-v0", + register_step_based=False, + mp_wrapper=suite.cartpole.TwoPolesMPWrapper, + add_mp_types=['DMP', 'ProMP'], ) -ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id) -kwargs_dict_cartpole2poles_promp = deepcopy(DEFAULT_BB_DICT_DMP) -kwargs_dict_cartpole2poles_promp['name'] = f"dm_control/cartpole-two_poles-v0" -kwargs_dict_cartpole2poles_promp['wrappers'].append(suite.cartpole.TwoPolesMPWrapper) -kwargs_dict_cartpole2poles_promp['controller_kwargs']['p_gains'] = 10 -kwargs_dict_cartpole2poles_promp['controller_kwargs']['d_gains'] = 10 -kwargs_dict_cartpole2poles_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2 -_env_id = f'dmc_cartpole-two_poles_promp-v0' register( - id=_env_id, - entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', - kwargs=kwargs_dict_cartpole2poles_promp + id=f"dm_control/cartpole-three_poles-v0", + register_step_based=False, + mp_wrapper=suite.cartpole.ThreePolesMPWrapper, + add_mp_types=['DMP', 'ProMP'], ) -ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) - -kwargs_dict_cartpole3poles_dmp = deepcopy(DEFAULT_BB_DICT_DMP) -kwargs_dict_cartpole3poles_dmp['name'] = f"dm_control/cartpole-three_poles-v0" -kwargs_dict_cartpole3poles_dmp['wrappers'].append(suite.cartpole.ThreePolesMPWrapper) -# bandwidth_factor = 2 -kwargs_dict_cartpole3poles_dmp['phase_generator_kwargs']['alpha_phase'] = 2 -# TODO: weight scale 50, but goal scale 0.1 -kwargs_dict_cartpole3poles_dmp['trajectory_generator_kwargs']['weight_scale'] = 500 
-kwargs_dict_cartpole3poles_dmp['controller_kwargs']['p_gains'] = 10 -kwargs_dict_cartpole3poles_dmp['controller_kwargs']['d_gains'] = 10 -_env_id = f'dmc_cartpole-three_poles_dmp-v0' -register( - id=_env_id, - entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', - kwargs=kwargs_dict_cartpole3poles_dmp -) -ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id) - -kwargs_dict_cartpole3poles_promp = deepcopy(DEFAULT_BB_DICT_DMP) -kwargs_dict_cartpole3poles_promp['name'] = f"dm_control/cartpole-three_poles-v0" -kwargs_dict_cartpole3poles_promp['wrappers'].append(suite.cartpole.ThreePolesMPWrapper) -kwargs_dict_cartpole3poles_promp['controller_kwargs']['p_gains'] = 10 -kwargs_dict_cartpole3poles_promp['controller_kwargs']['d_gains'] = 10 -kwargs_dict_cartpole3poles_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2 -_env_id = f'dmc_cartpole-three_poles_promp-v0' -register( - id=_env_id, - entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', - kwargs=kwargs_dict_cartpole3poles_promp -) -ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) # DeepMind Manipulation -kwargs_dict_mani_reach_site_features_dmp = deepcopy(DEFAULT_BB_DICT_DMP) -kwargs_dict_mani_reach_site_features_dmp['name'] = f"dm_control/reach_site_features-v0" -kwargs_dict_mani_reach_site_features_dmp['wrappers'].append(manipulation.reach_site.MPWrapper) -kwargs_dict_mani_reach_site_features_dmp['phase_generator_kwargs']['alpha_phase'] = 2 -# TODO: weight scale 50, but goal scale 0.1 -kwargs_dict_mani_reach_site_features_dmp['trajectory_generator_kwargs']['weight_scale'] = 500 -kwargs_dict_mani_reach_site_features_dmp['controller_kwargs']['controller_type'] = 'velocity' register( - id=f'dmc_manipulation-reach_site_dmp-v0', - entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', - kwargs=kwargs_dict_mani_reach_site_features_dmp + id=f"dm_control/reach_site_features-v0", + register_step_based=False, + 
mp_wrapper=manipulation.reach_site.MPWrapper, + add_mp_types=['DMP', 'ProMP'], ) -ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_manipulation-reach_site_dmp-v0") - -kwargs_dict_mani_reach_site_features_promp = deepcopy(DEFAULT_BB_DICT_DMP) -kwargs_dict_mani_reach_site_features_promp['name'] = f"dm_control/reach_site_features-v0" -kwargs_dict_mani_reach_site_features_promp['wrappers'].append(manipulation.reach_site.MPWrapper) -kwargs_dict_mani_reach_site_features_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2 -kwargs_dict_mani_reach_site_features_promp['controller_kwargs']['controller_type'] = 'velocity' -register( - id=f'dmc_manipulation-reach_site_promp-v0', - entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', - kwargs=kwargs_dict_mani_reach_site_features_promp -) -ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_manipulation-reach_site_promp-v0") diff --git a/fancy_gym/dmc/manipulation/reach_site/mp_wrapper.py b/fancy_gym/dmc/manipulation/reach_site/mp_wrapper.py index 908cee1..bc3445a 100644 --- a/fancy_gym/dmc/manipulation/reach_site/mp_wrapper.py +++ b/fancy_gym/dmc/manipulation/reach_site/mp_wrapper.py @@ -6,6 +6,30 @@ from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper class MPWrapper(RawInterfaceWrapper): + mp_config = { + 'ProMP': { + 'controller_kwargs': { + 'controller_type': 'velocity', + 'p_gains': 50.0, + }, + 'trajectory_generator_kwargs': { + 'weight_scale': 0.2, + }, + }, + 'DMP': { + 'controller_kwargs': { + 'controller_type': 'velocity', + 'p_gains': 50.0, + }, + 'phase_generator_kwargs': { + 'alpha_phase': 2, + }, + 'trajectory_generator_kwargs': { + 'weight_scale': 500, + }, + }, + 'ProDMP': {}, + } @property def context_mask(self) -> np.ndarray: diff --git a/fancy_gym/dmc/suite/ball_in_cup/mp_wrapper.py b/fancy_gym/dmc/suite/ball_in_cup/mp_wrapper.py index 94f9041..aef9896 100644 --- a/fancy_gym/dmc/suite/ball_in_cup/mp_wrapper.py +++ 
b/fancy_gym/dmc/suite/ball_in_cup/mp_wrapper.py @@ -6,6 +6,25 @@ from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper class MPWrapper(RawInterfaceWrapper): + mp_config = { + 'ProMP': { + 'controller_kwargs': { + 'p_gains': 50.0, + }, + }, + 'DMP': { + 'controller_kwargs': { + 'p_gains': 50.0, + }, + 'phase_generator_kwargs': { + 'alpha_phase': 2, + }, + 'trajectory_generator_kwargs': { + 'weight_scale': 10 + }, + }, + 'ProDMP': {}, + } @property def context_mask(self) -> np.ndarray: diff --git a/fancy_gym/dmc/suite/cartpole/mp_wrapper.py b/fancy_gym/dmc/suite/cartpole/mp_wrapper.py index 85afa83..9373cf2 100644 --- a/fancy_gym/dmc/suite/cartpole/mp_wrapper.py +++ b/fancy_gym/dmc/suite/cartpole/mp_wrapper.py @@ -6,6 +6,30 @@ from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper class MPWrapper(RawInterfaceWrapper): + mp_config = { + 'ProMP': { + 'controller_kwargs': { + 'p_gains': 10, + 'd_gains': 10, + }, + 'trajectory_generator_kwargs': { + 'weight_scale': 0.2, + }, + }, + 'DMP': { + 'controller_kwargs': { + 'p_gains': 10, + 'd_gains': 10, + }, + 'phase_generator_kwargs': { + 'alpha_phase': 2, + }, + 'trajectory_generator_kwargs': { + 'weight_scale': 500, + }, + }, + 'ProDMP': {}, + } def __init__(self, env, n_poles: int = 1): self.n_poles = n_poles diff --git a/fancy_gym/dmc/suite/reacher/mp_wrapper.py b/fancy_gym/dmc/suite/reacher/mp_wrapper.py index 2d0aee5..5fcf5a7 100644 --- a/fancy_gym/dmc/suite/reacher/mp_wrapper.py +++ b/fancy_gym/dmc/suite/reacher/mp_wrapper.py @@ -6,6 +6,26 @@ from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper class MPWrapper(RawInterfaceWrapper): + mp_config = { + 'ProMP': { + 'controller_kwargs': { + 'p_gains': 50.0, + }, + 'trajectory_generator_kwargs': {'weight_scale': 0.2}, + }, + 'DMP': { + 'controller_kwargs': { + 'p_gains': 50.0, + }, + 'phase_generator_kwargs': { + 'alpha_phase': 2, + }, + 'trajectory_generator_kwargs': { + 'weight_scale': 500, + }, + }, + 'ProDMP': {}, + } @property def context_mask(self) -> 
np.ndarray: