add ProDMP to register
This commit is contained in:
parent
d73c9bbdbf
commit
187c5f5bb2
@ -52,6 +52,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
||||
# self.time_steps = np.linspace(0, self.duration, self.traj_steps)
|
||||
# self.traj_gen.set_mp_times(self.time_steps)
|
||||
self.traj_gen.set_duration(self.duration, self.dt)
|
||||
# self.traj_gen.basis_gn.show_basis(plot=True)
|
||||
|
||||
# reward computation
|
||||
self.reward_aggregation = reward_aggregation
|
||||
|
@ -2,7 +2,7 @@ from copy import deepcopy
|
||||
|
||||
from . import manipulation, suite
|
||||
|
||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
|
||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}
|
||||
|
||||
from gym.envs.registration import register
|
||||
|
||||
|
@ -18,7 +18,7 @@ from .mujoco.reacher.reacher import ReacherEnv, MAX_EPISODE_STEPS_REACHER
|
||||
from .mujoco.walker_2d_jump.walker_2d_jump import MAX_EPISODE_STEPS_WALKERJUMP
|
||||
from .mujoco.box_pushing.box_pushing_env import BoxPushingEnv, MAX_EPISODE_STEPS_BOX_PUSHING
|
||||
|
||||
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
|
||||
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}
|
||||
|
||||
DEFAULT_BB_DICT_ProMP = {
|
||||
"name": 'EnvName',
|
||||
@ -62,6 +62,36 @@ DEFAULT_BB_DICT_DMP = {
|
||||
}
|
||||
}
|
||||
|
||||
DEFAULT_BB_DICT_ProDMP = {
|
||||
"name": 'EnvName',
|
||||
"wrappers": [],
|
||||
"trajectory_generator_kwargs": {
|
||||
'trajectory_generator_type': 'prodmp',
|
||||
'duration': 2.0,
|
||||
'weight_scale': 1.0,
|
||||
},
|
||||
"phase_generator_kwargs": {
|
||||
'phase_generator_type': 'exp',
|
||||
'learn_delay': False,
|
||||
'learn_tau': False,
|
||||
},
|
||||
"controller_kwargs": {
|
||||
'controller_type': 'motor',
|
||||
"p_gains": 1.0,
|
||||
"d_gains": 0.1,
|
||||
},
|
||||
"basis_generator_kwargs": {
|
||||
'basis_generator_type': 'prodmp',
|
||||
'alpha': 10,
|
||||
'num_basis': 5,
|
||||
},
|
||||
"black_box_kwargs": {
|
||||
'replanning_schedule': None,
|
||||
'verbose': 2,
|
||||
'enable_traj_level_reward': False,
|
||||
}
|
||||
}
|
||||
|
||||
# Classic Control
|
||||
## Simple Reacher
|
||||
register(
|
||||
@ -456,7 +486,7 @@ for _v in _versions:
|
||||
kwargs_dict_box_pushing_promp['name'] = _v
|
||||
kwargs_dict_box_pushing_promp['controller_kwargs']['p_gains'] = 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.])
|
||||
kwargs_dict_box_pushing_promp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.])
|
||||
kwargs_dict_box_pushing_promp['basis_generator_kwargs']['basis_bandwidth_factor'] = 2
|
||||
kwargs_dict_box_pushing_promp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3.5 # 3.5, 4 to try
|
||||
|
||||
register(
|
||||
id=_env_id,
|
||||
@ -465,6 +495,11 @@ for _v in _versions:
|
||||
)
|
||||
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
|
||||
for _v in _versions:
|
||||
_name = _v.split("-")
|
||||
_env_id = f'{_name[0]}ProDMP-{_name[1]}'
|
||||
kwargs_dict_box_pushing_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
|
||||
kwargs_dict_box_pushing_prodmp['wrappers'].append(mujoco.box_pushing.MPWrapper)
|
||||
#
|
||||
# ## Walker2DJump
|
||||
# _versions = ['Walker2DJump-v0']
|
||||
|
@ -5,7 +5,7 @@ from gym import register
|
||||
from . import goal_object_change_mp_wrapper, goal_change_mp_wrapper, goal_endeffector_change_mp_wrapper, \
|
||||
object_change_mp_wrapper
|
||||
|
||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
|
||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}
|
||||
|
||||
# MetaWorld
|
||||
|
||||
|
@ -5,7 +5,7 @@ from gym import register
|
||||
from . import mujoco
|
||||
from .deprecated_needs_gym_robotics import robotics
|
||||
|
||||
ALL_GYM_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
|
||||
ALL_GYM_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}
|
||||
|
||||
DEFAULT_BB_DICT_ProMP = {
|
||||
"name": 'EnvName',
|
||||
|
Loading…
Reference in New Issue
Block a user