Compare commits
2 Commits
Author | SHA1 | Date | |
---|---|---|---|
34a16ea5fe | |||
bb42a500ee |
@ -1,10 +1,10 @@
|
|||||||
|
from gym.envs.registration import register
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
from . import manipulation, suite
|
from . import manipulation, suite
|
||||||
|
|
||||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}
|
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}
|
||||||
|
|
||||||
from gym.envs.registration import register
|
|
||||||
|
|
||||||
DEFAULT_BB_DICT_ProMP = {
|
DEFAULT_BB_DICT_ProMP = {
|
||||||
"name": 'EnvName',
|
"name": 'EnvName',
|
||||||
@ -61,7 +61,7 @@ register(
|
|||||||
)
|
)
|
||||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_ball_in_cup-catch_dmp-v0")
|
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_ball_in_cup-catch_dmp-v0")
|
||||||
|
|
||||||
kwargs_dict_bic_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
kwargs_dict_bic_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
||||||
kwargs_dict_bic_promp['name'] = f"dmc:ball_in_cup-catch"
|
kwargs_dict_bic_promp['name'] = f"dmc:ball_in_cup-catch"
|
||||||
kwargs_dict_bic_promp['wrappers'].append(suite.ball_in_cup.MPWrapper)
|
kwargs_dict_bic_promp['wrappers'].append(suite.ball_in_cup.MPWrapper)
|
||||||
register(
|
register(
|
||||||
@ -85,7 +85,7 @@ register(
|
|||||||
)
|
)
|
||||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-easy_dmp-v0")
|
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-easy_dmp-v0")
|
||||||
|
|
||||||
kwargs_dict_reacher_easy_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
kwargs_dict_reacher_easy_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
||||||
kwargs_dict_reacher_easy_promp['name'] = f"dmc:reacher-easy"
|
kwargs_dict_reacher_easy_promp['name'] = f"dmc:reacher-easy"
|
||||||
kwargs_dict_reacher_easy_promp['wrappers'].append(suite.reacher.MPWrapper)
|
kwargs_dict_reacher_easy_promp['wrappers'].append(suite.reacher.MPWrapper)
|
||||||
kwargs_dict_reacher_easy_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2
|
kwargs_dict_reacher_easy_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2
|
||||||
@ -110,7 +110,7 @@ register(
|
|||||||
)
|
)
|
||||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-hard_dmp-v0")
|
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-hard_dmp-v0")
|
||||||
|
|
||||||
kwargs_dict_reacher_hard_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
kwargs_dict_reacher_hard_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
||||||
kwargs_dict_reacher_hard_promp['name'] = f"dmc:reacher-hard"
|
kwargs_dict_reacher_hard_promp['name'] = f"dmc:reacher-hard"
|
||||||
kwargs_dict_reacher_hard_promp['wrappers'].append(suite.reacher.MPWrapper)
|
kwargs_dict_reacher_hard_promp['wrappers'].append(suite.reacher.MPWrapper)
|
||||||
kwargs_dict_reacher_hard_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2
|
kwargs_dict_reacher_hard_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2
|
||||||
@ -142,7 +142,7 @@ for _task in _dmc_cartpole_tasks:
|
|||||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
|
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
|
||||||
|
|
||||||
_env_id = f'dmc_cartpole-{_task}_promp-v0'
|
_env_id = f'dmc_cartpole-{_task}_promp-v0'
|
||||||
kwargs_dict_cartpole_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
kwargs_dict_cartpole_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
||||||
kwargs_dict_cartpole_promp['name'] = f"dmc:cartpole-{_task}"
|
kwargs_dict_cartpole_promp['name'] = f"dmc:cartpole-{_task}"
|
||||||
kwargs_dict_cartpole_promp['wrappers'].append(suite.cartpole.MPWrapper)
|
kwargs_dict_cartpole_promp['wrappers'].append(suite.cartpole.MPWrapper)
|
||||||
kwargs_dict_cartpole_promp['controller_kwargs']['p_gains'] = 10
|
kwargs_dict_cartpole_promp['controller_kwargs']['p_gains'] = 10
|
||||||
@ -172,7 +172,7 @@ register(
|
|||||||
)
|
)
|
||||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
|
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
|
||||||
|
|
||||||
kwargs_dict_cartpole2poles_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
kwargs_dict_cartpole2poles_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
||||||
kwargs_dict_cartpole2poles_promp['name'] = f"dmc:cartpole-two_poles"
|
kwargs_dict_cartpole2poles_promp['name'] = f"dmc:cartpole-two_poles"
|
||||||
kwargs_dict_cartpole2poles_promp['wrappers'].append(suite.cartpole.TwoPolesMPWrapper)
|
kwargs_dict_cartpole2poles_promp['wrappers'].append(suite.cartpole.TwoPolesMPWrapper)
|
||||||
kwargs_dict_cartpole2poles_promp['controller_kwargs']['p_gains'] = 10
|
kwargs_dict_cartpole2poles_promp['controller_kwargs']['p_gains'] = 10
|
||||||
@ -203,7 +203,7 @@ register(
|
|||||||
)
|
)
|
||||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
|
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
|
||||||
|
|
||||||
kwargs_dict_cartpole3poles_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
kwargs_dict_cartpole3poles_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
||||||
kwargs_dict_cartpole3poles_promp['name'] = f"dmc:cartpole-three_poles"
|
kwargs_dict_cartpole3poles_promp['name'] = f"dmc:cartpole-three_poles"
|
||||||
kwargs_dict_cartpole3poles_promp['wrappers'].append(suite.cartpole.ThreePolesMPWrapper)
|
kwargs_dict_cartpole3poles_promp['wrappers'].append(suite.cartpole.ThreePolesMPWrapper)
|
||||||
kwargs_dict_cartpole3poles_promp['controller_kwargs']['p_gains'] = 10
|
kwargs_dict_cartpole3poles_promp['controller_kwargs']['p_gains'] = 10
|
||||||
@ -232,7 +232,7 @@ register(
|
|||||||
)
|
)
|
||||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_manipulation-reach_site_dmp-v0")
|
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_manipulation-reach_site_dmp-v0")
|
||||||
|
|
||||||
kwargs_dict_mani_reach_site_features_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
kwargs_dict_mani_reach_site_features_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
||||||
kwargs_dict_mani_reach_site_features_promp['name'] = f"dmc:manipulation-reach_site_features"
|
kwargs_dict_mani_reach_site_features_promp['name'] = f"dmc:manipulation-reach_site_features"
|
||||||
kwargs_dict_mani_reach_site_features_promp['wrappers'].append(manipulation.reach_site.MPWrapper)
|
kwargs_dict_mani_reach_site_features_promp['wrappers'].append(manipulation.reach_site.MPWrapper)
|
||||||
kwargs_dict_mani_reach_site_features_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2
|
kwargs_dict_mani_reach_site_features_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2
|
||||||
|
Loading…
Reference in New Issue
Block a user