Ported metaworld to mp-config
This commit is contained in:
parent
e743663018
commit
e63a0a50df
@@ -2,7 +2,7 @@ from typing import Iterable, Type, Union, Optional
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
from gymnasium import register
|
||||
from ..envs.registry import register
|
||||
|
||||
from . import goal_object_change_mp_wrapper, goal_change_mp_wrapper, goal_endeffector_change_mp_wrapper, \
|
||||
object_change_mp_wrapper
|
||||
@@ -14,118 +14,24 @@ metaworld_adapter.register_all_ML1()
|
||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}
|
||||
|
||||
# MetaWorld
|
||||
|
||||
DEFAULT_BB_DICT_ProMP = {
|
||||
"name": 'EnvName',
|
||||
"wrappers": [],
|
||||
"trajectory_generator_kwargs": {
|
||||
'trajectory_generator_type': 'promp',
|
||||
'weights_scale': 10,
|
||||
},
|
||||
"phase_generator_kwargs": {
|
||||
'phase_generator_type': 'linear'
|
||||
},
|
||||
"controller_kwargs": {
|
||||
'controller_type': 'metaworld',
|
||||
},
|
||||
"basis_generator_kwargs": {
|
||||
'basis_generator_type': 'zero_rbf',
|
||||
'num_basis': 5,
|
||||
'num_basis_zero_start': 1
|
||||
},
|
||||
'black_box_kwargs': {
|
||||
'condition_on_desired': False,
|
||||
}
|
||||
}
|
||||
|
||||
DEFAULT_BB_DICT_ProDMP = {
|
||||
"name": 'EnvName',
|
||||
"wrappers": [],
|
||||
"trajectory_generator_kwargs": {
|
||||
'trajectory_generator_type': 'prodmp',
|
||||
'auto_scale_basis': True,
|
||||
'weights_scale': 10,
|
||||
# 'goal_scale': 0.,
|
||||
'disable_goal': True,
|
||||
},
|
||||
"phase_generator_kwargs": {
|
||||
'phase_generator_type': 'exp',
|
||||
# 'alpha_phase' : 3,
|
||||
},
|
||||
"controller_kwargs": {
|
||||
'controller_type': 'metaworld',
|
||||
},
|
||||
"basis_generator_kwargs": {
|
||||
'basis_generator_type': 'prodmp',
|
||||
'num_basis': 5,
|
||||
'alpha': 10
|
||||
},
|
||||
'black_box_kwargs': {
|
||||
'condition_on_desired': False,
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
_goal_change_envs = ["assembly-v2", "pick-out-of-hole-v2", "plate-slide-v2", "plate-slide-back-v2",
|
||||
"plate-slide-side-v2", "plate-slide-back-side-v2"]
|
||||
for _task in _goal_change_envs:
|
||||
task_id_split = _task.split("-")
|
||||
name = "".join([s.capitalize() for s in task_id_split[:-1]])
|
||||
|
||||
# ProMP
|
||||
_env_id = f'{name}ProMP-{task_id_split[-1]}'
|
||||
kwargs_dict_goal_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
||||
kwargs_dict_goal_change_promp['wrappers'].append(goal_change_mp_wrapper.MPWrapper)
|
||||
kwargs_dict_goal_change_promp['name'] = f'metaworld:{_task}'
|
||||
|
||||
register(
|
||||
id=_env_id,
|
||||
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_goal_change_promp
|
||||
id=f'metaworld/{_task}',
|
||||
register_step_based=False,
|
||||
mp_wrapper=goal_change_mp_wrapper.MPWrapper,
|
||||
add_mp_types=['ProMP', 'ProDMP'],
|
||||
)
|
||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
|
||||
# ProDMP
|
||||
_env_id = f'{name}ProDMP-{task_id_split[-1]}'
|
||||
kwargs_dict_goal_change_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
|
||||
kwargs_dict_goal_change_prodmp['wrappers'].append(goal_change_mp_wrapper.MPWrapper)
|
||||
kwargs_dict_goal_change_prodmp['name'] = f'metaworld:{_task}'
|
||||
|
||||
register(
|
||||
id=_env_id,
|
||||
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_goal_change_prodmp
|
||||
)
|
||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
|
||||
|
||||
_object_change_envs = ["bin-picking-v2", "hammer-v2", "sweep-into-v2"]
|
||||
for _task in _object_change_envs:
|
||||
task_id_split = _task.split("-")
|
||||
name = "".join([s.capitalize() for s in task_id_split[:-1]])
|
||||
|
||||
# ProMP
|
||||
_env_id = f'{name}ProMP-{task_id_split[-1]}'
|
||||
kwargs_dict_object_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
||||
kwargs_dict_object_change_promp['wrappers'].append(object_change_mp_wrapper.MPWrapper)
|
||||
kwargs_dict_object_change_promp['name'] = f'metaworld:{_task}'
|
||||
register(
|
||||
id=_env_id,
|
||||
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_object_change_promp
|
||||
id=f'metaworld/{_task}',
|
||||
register_step_based=False,
|
||||
mp_wrapper=object_change_mp_wrapper.MPWrapper,
|
||||
add_mp_types=['ProMP', 'ProDMP'],
|
||||
)
|
||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
|
||||
# ProDMP
|
||||
_env_id = f'{name}ProDMP-{task_id_split[-1]}'
|
||||
kwargs_dict_object_change_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
|
||||
kwargs_dict_object_change_prodmp['wrappers'].append(object_change_mp_wrapper.MPWrapper)
|
||||
kwargs_dict_object_change_prodmp['name'] = f'metaworld:{_task}'
|
||||
register(
|
||||
id=_env_id,
|
||||
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_object_change_prodmp
|
||||
)
|
||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
|
||||
|
||||
_goal_and_object_change_envs = ["box-close-v2", "button-press-v2", "button-press-wall-v2", "button-press-topdown-v2",
|
||||
"button-press-topdown-wall-v2", "coffee-button-v2", "coffee-pull-v2",
|
||||
@@ -139,62 +45,18 @@ _goal_and_object_change_envs = ["box-close-v2", "button-press-v2", "button-press
|
||||
"shelf-place-v2", "sweep-v2", "window-open-v2", "window-close-v2"
|
||||
]
|
||||
for _task in _goal_and_object_change_envs:
|
||||
task_id_split = _task.split("-")
|
||||
name = "".join([s.capitalize() for s in task_id_split[:-1]])
|
||||
|
||||
# ProMP
|
||||
_env_id = f'{name}ProMP-{task_id_split[-1]}'
|
||||
kwargs_dict_goal_and_object_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
||||
kwargs_dict_goal_and_object_change_promp['wrappers'].append(goal_object_change_mp_wrapper.MPWrapper)
|
||||
kwargs_dict_goal_and_object_change_promp['name'] = f'metaworld:{_task}'
|
||||
|
||||
register(
|
||||
id=_env_id,
|
||||
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_goal_and_object_change_promp
|
||||
id=f'metaworld/{_task}',
|
||||
register_step_based=False,
|
||||
mp_wrapper=goal_object_change_mp_wrapper.MPWrapper,
|
||||
add_mp_types=['ProMP', 'ProDMP'],
|
||||
)
|
||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
|
||||
# ProDMP
|
||||
_env_id = f'{name}ProDMP-{task_id_split[-1]}'
|
||||
kwargs_dict_goal_and_object_change_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
|
||||
kwargs_dict_goal_and_object_change_prodmp['wrappers'].append(goal_object_change_mp_wrapper.MPWrapper)
|
||||
kwargs_dict_goal_and_object_change_prodmp['name'] = f'metaworld:{_task}'
|
||||
|
||||
register(
|
||||
id=_env_id,
|
||||
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_goal_and_object_change_prodmp
|
||||
)
|
||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
|
||||
|
||||
_goal_and_endeffector_change_envs = ["basketball-v2"]
|
||||
for _task in _goal_and_endeffector_change_envs:
|
||||
task_id_split = _task.split("-")
|
||||
name = "".join([s.capitalize() for s in task_id_split[:-1]])
|
||||
|
||||
# ProMP
|
||||
_env_id = f'{name}ProMP-{task_id_split[-1]}'
|
||||
kwargs_dict_goal_and_endeffector_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
||||
kwargs_dict_goal_and_endeffector_change_promp['wrappers'].append(goal_endeffector_change_mp_wrapper.MPWrapper)
|
||||
kwargs_dict_goal_and_endeffector_change_promp['name'] = f'metaworld:{_task}'
|
||||
|
||||
register(
|
||||
id=_env_id,
|
||||
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_goal_and_endeffector_change_promp
|
||||
id=f'metaworld/{_task}',
|
||||
register_step_based=False,
|
||||
mp_wrapper=goal_endeffector_change_mp_wrapper.MPWrapper,
|
||||
add_mp_types=['ProMP', 'ProDMP'],
|
||||
)
|
||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
|
||||
# ProDMP
|
||||
_env_id = f'{name}ProDMP-{task_id_split[-1]}'
|
||||
kwargs_dict_goal_and_endeffector_change_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
|
||||
kwargs_dict_goal_and_endeffector_change_prodmp['wrappers'].append(goal_endeffector_change_mp_wrapper.MPWrapper)
|
||||
kwargs_dict_goal_and_endeffector_change_prodmp['name'] = f'metaworld:{_task}'
|
||||
|
||||
register(
|
||||
id=_env_id,
|
||||
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_goal_and_endeffector_change_prodmp
|
||||
)
|
||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
|
||||
|
@@ -6,12 +6,63 @@ from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
||||
|
||||
|
||||
class BaseMetaworldMPWrapper(RawInterfaceWrapper):
|
||||
mp_config = {
|
||||
'inherit_defaults': False,
|
||||
'ProMP': {
|
||||
'wrappers': [],
|
||||
'trajectory_generator_kwargs': {
|
||||
'trajectory_generator_type': 'promp',
|
||||
'weights_scale': 10,
|
||||
},
|
||||
'phase_generator_kwargs': {
|
||||
'phase_generator_type': 'linear'
|
||||
},
|
||||
'controller_kwargs': {
|
||||
'controller_type': 'metaworld',
|
||||
},
|
||||
'basis_generator_kwargs': {
|
||||
'basis_generator_type': 'zero_rbf',
|
||||
'num_basis': 5,
|
||||
'num_basis_zero_start': 1
|
||||
},
|
||||
'black_box_kwargs': {
|
||||
'condition_on_desired': False,
|
||||
},
|
||||
},
|
||||
'DMP': {},
|
||||
'ProDMP': {
|
||||
'wrappers': [],
|
||||
'trajectory_generator_kwargs': {
|
||||
'trajectory_generator_type': 'prodmp',
|
||||
'auto_scale_basis': True,
|
||||
'weights_scale': 10,
|
||||
# 'goal_scale': 0.,
|
||||
'disable_goal': True,
|
||||
},
|
||||
'phase_generator_kwargs': {
|
||||
'phase_generator_type': 'exp',
|
||||
# 'alpha_phase' : 3,
|
||||
},
|
||||
'controller_kwargs': {
|
||||
'controller_type': 'metaworld',
|
||||
},
|
||||
'basis_generator_kwargs': {
|
||||
'basis_generator_type': 'prodmp',
|
||||
'num_basis': 5,
|
||||
'alpha': 10
|
||||
},
|
||||
'black_box_kwargs': {
|
||||
'condition_on_desired': False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@property
|
||||
def current_pos(self) -> Union[float, int, np.ndarray]:
|
||||
r_close = self.env.data.get_joint_qpos("r_close")
|
||||
r_close = self.env.data.get_joint_qpos('r_close')
|
||||
return np.hstack([self.env.data.mocap_pos.flatten() / self.env.action_scale, r_close])
|
||||
|
||||
@property
|
||||
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||
return np.zeros(4, )
|
||||
# raise NotImplementedError("Velocity cannot be retrieved.")
|
||||
# raise NotImplementedError('Velocity cannot be retrieved.')
|
||||
|
Loading…
Reference in New Issue
Block a user