add ProDMP to register

This commit is contained in:
Hongyi Zhou 2022-10-13 10:57:00 +02:00
parent d73c9bbdbf
commit 187c5f5bb2
5 changed files with 41 additions and 5 deletions

View File

@ -52,6 +52,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
# self.time_steps = np.linspace(0, self.duration, self.traj_steps)
# self.traj_gen.set_mp_times(self.time_steps)
self.traj_gen.set_duration(self.duration, self.dt)
# self.traj_gen.basis_gn.show_basis(plot=True)
# reward computation
self.reward_aggregation = reward_aggregation

View File

@ -2,7 +2,7 @@ from copy import deepcopy
from . import manipulation, suite
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}
from gym.envs.registration import register

View File

@ -18,7 +18,7 @@ from .mujoco.reacher.reacher import ReacherEnv, MAX_EPISODE_STEPS_REACHER
from .mujoco.walker_2d_jump.walker_2d_jump import MAX_EPISODE_STEPS_WALKERJUMP
from .mujoco.box_pushing.box_pushing_env import BoxPushingEnv, MAX_EPISODE_STEPS_BOX_PUSHING
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}
DEFAULT_BB_DICT_ProMP = {
"name": 'EnvName',
@ -62,6 +62,36 @@ DEFAULT_BB_DICT_DMP = {
}
}
DEFAULT_BB_DICT_ProDMP = {
"name": 'EnvName',
"wrappers": [],
"trajectory_generator_kwargs": {
'trajectory_generator_type': 'prodmp',
'duration': 2.0,
'weight_scale': 1.0,
},
"phase_generator_kwargs": {
'phase_generator_type': 'exp',
'learn_delay': False,
'learn_tau': False,
},
"controller_kwargs": {
'controller_type': 'motor',
"p_gains": 1.0,
"d_gains": 0.1,
},
"basis_generator_kwargs": {
'basis_generator_type': 'prodmp',
'alpha': 10,
'num_basis': 5,
},
"black_box_kwargs": {
'replanning_schedule': None,
'verbose': 2,
'enable_traj_level_reward': False,
}
}
# Classic Control
## Simple Reacher
register(
@ -456,7 +486,7 @@ for _v in _versions:
kwargs_dict_box_pushing_promp['name'] = _v
kwargs_dict_box_pushing_promp['controller_kwargs']['p_gains'] = 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.])
kwargs_dict_box_pushing_promp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.])
kwargs_dict_box_pushing_promp['basis_generator_kwargs']['basis_bandwidth_factor'] = 2
kwargs_dict_box_pushing_promp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3.5 # 3.5, 4 to try
register(
id=_env_id,
@ -465,6 +495,11 @@ for _v in _versions:
)
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
for _v in _versions:
_name = _v.split("-")
_env_id = f'{_name[0]}ProDMP-{_name[1]}'
kwargs_dict_box_pushing_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
kwargs_dict_box_pushing_prodmp['wrappers'].append(mujoco.box_pushing.MPWrapper)
#
# ## Walker2DJump
# _versions = ['Walker2DJump-v0']

View File

@ -5,7 +5,7 @@ from gym import register
from . import goal_object_change_mp_wrapper, goal_change_mp_wrapper, goal_endeffector_change_mp_wrapper, \
object_change_mp_wrapper
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}
# MetaWorld

View File

@ -5,7 +5,7 @@ from gym import register
from . import mujoco
from .deprecated_needs_gym_robotics import robotics
ALL_GYM_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
ALL_GYM_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}
DEFAULT_BB_DICT_ProMP = {
"name": 'EnvName',