add ProDMP to register
This commit is contained in:
parent
d73c9bbdbf
commit
187c5f5bb2
@ -52,6 +52,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
# self.time_steps = np.linspace(0, self.duration, self.traj_steps)
|
# self.time_steps = np.linspace(0, self.duration, self.traj_steps)
|
||||||
# self.traj_gen.set_mp_times(self.time_steps)
|
# self.traj_gen.set_mp_times(self.time_steps)
|
||||||
self.traj_gen.set_duration(self.duration, self.dt)
|
self.traj_gen.set_duration(self.duration, self.dt)
|
||||||
|
# self.traj_gen.basis_gn.show_basis(plot=True)
|
||||||
|
|
||||||
# reward computation
|
# reward computation
|
||||||
self.reward_aggregation = reward_aggregation
|
self.reward_aggregation = reward_aggregation
|
||||||
|
@ -2,7 +2,7 @@ from copy import deepcopy
|
|||||||
|
|
||||||
from . import manipulation, suite
|
from . import manipulation, suite
|
||||||
|
|
||||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
|
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}
|
||||||
|
|
||||||
from gym.envs.registration import register
|
from gym.envs.registration import register
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@ from .mujoco.reacher.reacher import ReacherEnv, MAX_EPISODE_STEPS_REACHER
|
|||||||
from .mujoco.walker_2d_jump.walker_2d_jump import MAX_EPISODE_STEPS_WALKERJUMP
|
from .mujoco.walker_2d_jump.walker_2d_jump import MAX_EPISODE_STEPS_WALKERJUMP
|
||||||
from .mujoco.box_pushing.box_pushing_env import BoxPushingEnv, MAX_EPISODE_STEPS_BOX_PUSHING
|
from .mujoco.box_pushing.box_pushing_env import BoxPushingEnv, MAX_EPISODE_STEPS_BOX_PUSHING
|
||||||
|
|
||||||
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
|
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}
|
||||||
|
|
||||||
DEFAULT_BB_DICT_ProMP = {
|
DEFAULT_BB_DICT_ProMP = {
|
||||||
"name": 'EnvName',
|
"name": 'EnvName',
|
||||||
@ -62,6 +62,36 @@ DEFAULT_BB_DICT_DMP = {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DEFAULT_BB_DICT_ProDMP = {
|
||||||
|
"name": 'EnvName',
|
||||||
|
"wrappers": [],
|
||||||
|
"trajectory_generator_kwargs": {
|
||||||
|
'trajectory_generator_type': 'prodmp',
|
||||||
|
'duration': 2.0,
|
||||||
|
'weight_scale': 1.0,
|
||||||
|
},
|
||||||
|
"phase_generator_kwargs": {
|
||||||
|
'phase_generator_type': 'exp',
|
||||||
|
'learn_delay': False,
|
||||||
|
'learn_tau': False,
|
||||||
|
},
|
||||||
|
"controller_kwargs": {
|
||||||
|
'controller_type': 'motor',
|
||||||
|
"p_gains": 1.0,
|
||||||
|
"d_gains": 0.1,
|
||||||
|
},
|
||||||
|
"basis_generator_kwargs": {
|
||||||
|
'basis_generator_type': 'prodmp',
|
||||||
|
'alpha': 10,
|
||||||
|
'num_basis': 5,
|
||||||
|
},
|
||||||
|
"black_box_kwargs": {
|
||||||
|
'replanning_schedule': None,
|
||||||
|
'verbose': 2,
|
||||||
|
'enable_traj_level_reward': False,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
# Classic Control
|
# Classic Control
|
||||||
## Simple Reacher
|
## Simple Reacher
|
||||||
register(
|
register(
|
||||||
@ -456,7 +486,7 @@ for _v in _versions:
|
|||||||
kwargs_dict_box_pushing_promp['name'] = _v
|
kwargs_dict_box_pushing_promp['name'] = _v
|
||||||
kwargs_dict_box_pushing_promp['controller_kwargs']['p_gains'] = 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.])
|
kwargs_dict_box_pushing_promp['controller_kwargs']['p_gains'] = 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.])
|
||||||
kwargs_dict_box_pushing_promp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.])
|
kwargs_dict_box_pushing_promp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.])
|
||||||
kwargs_dict_box_pushing_promp['basis_generator_kwargs']['basis_bandwidth_factor'] = 2
|
kwargs_dict_box_pushing_promp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3.5 # 3.5, 4 to try
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id=_env_id,
|
id=_env_id,
|
||||||
@ -465,6 +495,11 @@ for _v in _versions:
|
|||||||
)
|
)
|
||||||
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||||
|
|
||||||
|
for _v in _versions:
|
||||||
|
_name = _v.split("-")
|
||||||
|
_env_id = f'{_name[0]}ProDMP-{_name[1]}'
|
||||||
|
kwargs_dict_box_pushing_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
|
||||||
|
kwargs_dict_box_pushing_prodmp['wrappers'].append(mujoco.box_pushing.MPWrapper)
|
||||||
#
|
#
|
||||||
# ## Walker2DJump
|
# ## Walker2DJump
|
||||||
# _versions = ['Walker2DJump-v0']
|
# _versions = ['Walker2DJump-v0']
|
||||||
|
@ -5,7 +5,7 @@ from gym import register
|
|||||||
from . import goal_object_change_mp_wrapper, goal_change_mp_wrapper, goal_endeffector_change_mp_wrapper, \
|
from . import goal_object_change_mp_wrapper, goal_change_mp_wrapper, goal_endeffector_change_mp_wrapper, \
|
||||||
object_change_mp_wrapper
|
object_change_mp_wrapper
|
||||||
|
|
||||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
|
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}
|
||||||
|
|
||||||
# MetaWorld
|
# MetaWorld
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ from gym import register
|
|||||||
from . import mujoco
|
from . import mujoco
|
||||||
from .deprecated_needs_gym_robotics import robotics
|
from .deprecated_needs_gym_robotics import robotics
|
||||||
|
|
||||||
ALL_GYM_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
|
ALL_GYM_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}
|
||||||
|
|
||||||
DEFAULT_BB_DICT_ProMP = {
|
DEFAULT_BB_DICT_ProMP = {
|
||||||
"name": 'EnvName',
|
"name": 'EnvName',
|
||||||
|
Loading…
Reference in New Issue
Block a user