Compare commits

...

2 Commits

View File

@ -1,10 +1,10 @@
from gym.envs.registration import register
from copy import deepcopy from copy import deepcopy
from . import manipulation, suite from . import manipulation, suite
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []} ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}
from gym.envs.registration import register
DEFAULT_BB_DICT_ProMP = { DEFAULT_BB_DICT_ProMP = {
"name": 'EnvName', "name": 'EnvName',
@ -61,7 +61,7 @@ register(
) )
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_ball_in_cup-catch_dmp-v0") ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_ball_in_cup-catch_dmp-v0")
kwargs_dict_bic_promp = deepcopy(DEFAULT_BB_DICT_DMP) kwargs_dict_bic_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
kwargs_dict_bic_promp['name'] = f"dmc:ball_in_cup-catch" kwargs_dict_bic_promp['name'] = f"dmc:ball_in_cup-catch"
kwargs_dict_bic_promp['wrappers'].append(suite.ball_in_cup.MPWrapper) kwargs_dict_bic_promp['wrappers'].append(suite.ball_in_cup.MPWrapper)
register( register(
@ -85,7 +85,7 @@ register(
) )
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-easy_dmp-v0") ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-easy_dmp-v0")
kwargs_dict_reacher_easy_promp = deepcopy(DEFAULT_BB_DICT_DMP) kwargs_dict_reacher_easy_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
kwargs_dict_reacher_easy_promp['name'] = f"dmc:reacher-easy" kwargs_dict_reacher_easy_promp['name'] = f"dmc:reacher-easy"
kwargs_dict_reacher_easy_promp['wrappers'].append(suite.reacher.MPWrapper) kwargs_dict_reacher_easy_promp['wrappers'].append(suite.reacher.MPWrapper)
kwargs_dict_reacher_easy_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2 kwargs_dict_reacher_easy_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2
@ -110,7 +110,7 @@ register(
) )
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-hard_dmp-v0") ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-hard_dmp-v0")
kwargs_dict_reacher_hard_promp = deepcopy(DEFAULT_BB_DICT_DMP) kwargs_dict_reacher_hard_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
kwargs_dict_reacher_hard_promp['name'] = f"dmc:reacher-hard" kwargs_dict_reacher_hard_promp['name'] = f"dmc:reacher-hard"
kwargs_dict_reacher_hard_promp['wrappers'].append(suite.reacher.MPWrapper) kwargs_dict_reacher_hard_promp['wrappers'].append(suite.reacher.MPWrapper)
kwargs_dict_reacher_hard_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2 kwargs_dict_reacher_hard_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2
@ -142,7 +142,7 @@ for _task in _dmc_cartpole_tasks:
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id) ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
_env_id = f'dmc_cartpole-{_task}_promp-v0' _env_id = f'dmc_cartpole-{_task}_promp-v0'
kwargs_dict_cartpole_promp = deepcopy(DEFAULT_BB_DICT_DMP) kwargs_dict_cartpole_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
kwargs_dict_cartpole_promp['name'] = f"dmc:cartpole-{_task}" kwargs_dict_cartpole_promp['name'] = f"dmc:cartpole-{_task}"
kwargs_dict_cartpole_promp['wrappers'].append(suite.cartpole.MPWrapper) kwargs_dict_cartpole_promp['wrappers'].append(suite.cartpole.MPWrapper)
kwargs_dict_cartpole_promp['controller_kwargs']['p_gains'] = 10 kwargs_dict_cartpole_promp['controller_kwargs']['p_gains'] = 10
@ -172,7 +172,7 @@ register(
) )
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id) ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
kwargs_dict_cartpole2poles_promp = deepcopy(DEFAULT_BB_DICT_DMP) kwargs_dict_cartpole2poles_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
kwargs_dict_cartpole2poles_promp['name'] = f"dmc:cartpole-two_poles" kwargs_dict_cartpole2poles_promp['name'] = f"dmc:cartpole-two_poles"
kwargs_dict_cartpole2poles_promp['wrappers'].append(suite.cartpole.TwoPolesMPWrapper) kwargs_dict_cartpole2poles_promp['wrappers'].append(suite.cartpole.TwoPolesMPWrapper)
kwargs_dict_cartpole2poles_promp['controller_kwargs']['p_gains'] = 10 kwargs_dict_cartpole2poles_promp['controller_kwargs']['p_gains'] = 10
@ -203,7 +203,7 @@ register(
) )
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id) ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
kwargs_dict_cartpole3poles_promp = deepcopy(DEFAULT_BB_DICT_DMP) kwargs_dict_cartpole3poles_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
kwargs_dict_cartpole3poles_promp['name'] = f"dmc:cartpole-three_poles" kwargs_dict_cartpole3poles_promp['name'] = f"dmc:cartpole-three_poles"
kwargs_dict_cartpole3poles_promp['wrappers'].append(suite.cartpole.ThreePolesMPWrapper) kwargs_dict_cartpole3poles_promp['wrappers'].append(suite.cartpole.ThreePolesMPWrapper)
kwargs_dict_cartpole3poles_promp['controller_kwargs']['p_gains'] = 10 kwargs_dict_cartpole3poles_promp['controller_kwargs']['p_gains'] = 10
@ -232,7 +232,7 @@ register(
) )
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_manipulation-reach_site_dmp-v0") ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_manipulation-reach_site_dmp-v0")
kwargs_dict_mani_reach_site_features_promp = deepcopy(DEFAULT_BB_DICT_DMP) kwargs_dict_mani_reach_site_features_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
kwargs_dict_mani_reach_site_features_promp['name'] = f"dmc:manipulation-reach_site_features" kwargs_dict_mani_reach_site_features_promp['name'] = f"dmc:manipulation-reach_site_features"
kwargs_dict_mani_reach_site_features_promp['wrappers'].append(manipulation.reach_site.MPWrapper) kwargs_dict_mani_reach_site_features_promp['wrappers'].append(manipulation.reach_site.MPWrapper)
kwargs_dict_mani_reach_site_features_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2 kwargs_dict_mani_reach_site_features_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2