diff --git a/fancy_gym/dmc/__init__.py b/fancy_gym/dmc/__init__.py index 397e6fa..4abcfe3 100644 --- a/fancy_gym/dmc/__init__.py +++ b/fancy_gym/dmc/__init__.py @@ -1,10 +1,10 @@ +from gym.envs.registration import register from copy import deepcopy from . import manipulation, suite ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []} -from gym.envs.registration import register DEFAULT_BB_DICT_ProMP = { "name": 'EnvName', @@ -61,7 +61,7 @@ register( ) ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_ball_in_cup-catch_dmp-v0") -kwargs_dict_bic_promp = deepcopy(DEFAULT_BB_DICT_DMP) +kwargs_dict_bic_promp = deepcopy(DEFAULT_BB_DICT_ProMP) kwargs_dict_bic_promp['name'] = f"dmc:ball_in_cup-catch" kwargs_dict_bic_promp['wrappers'].append(suite.ball_in_cup.MPWrapper) register( @@ -85,7 +85,7 @@ register( ) ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-easy_dmp-v0") -kwargs_dict_reacher_easy_promp = deepcopy(DEFAULT_BB_DICT_DMP) +kwargs_dict_reacher_easy_promp = deepcopy(DEFAULT_BB_DICT_ProMP) kwargs_dict_reacher_easy_promp['name'] = f"dmc:reacher-easy" kwargs_dict_reacher_easy_promp['wrappers'].append(suite.reacher.MPWrapper) kwargs_dict_reacher_easy_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2 @@ -110,7 +110,7 @@ register( ) ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-hard_dmp-v0") -kwargs_dict_reacher_hard_promp = deepcopy(DEFAULT_BB_DICT_DMP) +kwargs_dict_reacher_hard_promp = deepcopy(DEFAULT_BB_DICT_ProMP) kwargs_dict_reacher_hard_promp['name'] = f"dmc:reacher-hard" kwargs_dict_reacher_hard_promp['wrappers'].append(suite.reacher.MPWrapper) kwargs_dict_reacher_hard_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2 @@ -142,7 +142,7 @@ for _task in _dmc_cartpole_tasks: ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id) _env_id = f'dmc_cartpole-{_task}_promp-v0' - kwargs_dict_cartpole_promp = deepcopy(DEFAULT_BB_DICT_DMP) + kwargs_dict_cartpole_promp = deepcopy(DEFAULT_BB_DICT_ProMP) kwargs_dict_cartpole_promp['name'] = f"dmc:cartpole-{_task}" kwargs_dict_cartpole_promp['wrappers'].append(suite.cartpole.MPWrapper) kwargs_dict_cartpole_promp['controller_kwargs']['p_gains'] = 10 @@ -172,7 +172,7 @@ register( ) ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id) -kwargs_dict_cartpole2poles_promp = deepcopy(DEFAULT_BB_DICT_DMP) +kwargs_dict_cartpole2poles_promp = deepcopy(DEFAULT_BB_DICT_ProMP) kwargs_dict_cartpole2poles_promp['name'] = f"dmc:cartpole-two_poles" kwargs_dict_cartpole2poles_promp['wrappers'].append(suite.cartpole.TwoPolesMPWrapper) kwargs_dict_cartpole2poles_promp['controller_kwargs']['p_gains'] = 10 @@ -203,7 +203,7 @@ register( ) ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id) -kwargs_dict_cartpole3poles_promp = deepcopy(DEFAULT_BB_DICT_DMP) +kwargs_dict_cartpole3poles_promp = deepcopy(DEFAULT_BB_DICT_ProMP) kwargs_dict_cartpole3poles_promp['name'] = f"dmc:cartpole-three_poles" kwargs_dict_cartpole3poles_promp['wrappers'].append(suite.cartpole.ThreePolesMPWrapper) kwargs_dict_cartpole3poles_promp['controller_kwargs']['p_gains'] = 10 @@ -232,7 +232,7 @@ register( ) ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_manipulation-reach_site_dmp-v0") -kwargs_dict_mani_reach_site_features_promp = deepcopy(DEFAULT_BB_DICT_DMP) +kwargs_dict_mani_reach_site_features_promp = deepcopy(DEFAULT_BB_DICT_ProMP) kwargs_dict_mani_reach_site_features_promp['name'] = f"dmc:manipulation-reach_site_features" kwargs_dict_mani_reach_site_features_promp['wrappers'].append(manipulation.reach_site.MPWrapper) kwargs_dict_mani_reach_site_features_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2