Added more tests
This commit is contained in:
parent
915ffbe928
commit
00f622e913
@ -1,4 +1,5 @@
|
||||
from mp_pytorch.basis_gn import NormalizedRBFBasisGenerator, ZeroPaddingNormalizedRBFBasisGenerator
|
||||
from mp_pytorch.basis_gn import NormalizedRBFBasisGenerator, ZeroPaddingNormalizedRBFBasisGenerator, \
|
||||
ProDMPBasisGenerator
|
||||
from mp_pytorch.phase_gn import PhaseGenerator
|
||||
|
||||
ALL_TYPES = ["rbf", "zero_rbf", "rhythmic"]
|
||||
@ -10,6 +11,10 @@ def get_basis_generator(basis_generator_type: str, phase_generator: PhaseGenerat
|
||||
return NormalizedRBFBasisGenerator(phase_generator, **kwargs)
|
||||
elif basis_generator_type == "zero_rbf":
|
||||
return ZeroPaddingNormalizedRBFBasisGenerator(phase_generator, **kwargs)
|
||||
elif basis_generator_type == "prodmp":
|
||||
from mp_pytorch.phase_gn import ExpDecayPhaseGenerator
|
||||
assert isinstance(phase_generator, ExpDecayPhaseGenerator)
|
||||
return ProDMPBasisGenerator(phase_generator, **kwargs)
|
||||
elif basis_generator_type == "rhythmic":
|
||||
raise NotImplementedError()
|
||||
# return RhythmicBasisGenerator(phase_generator, **kwargs)
|
||||
|
@ -2,7 +2,7 @@ from copy import deepcopy
|
||||
|
||||
from . import manipulation, suite
|
||||
|
||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
|
||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}
|
||||
|
||||
from gym.envs.registration import register
|
||||
|
||||
|
@ -17,7 +17,7 @@ from .mujoco.hopper_throw.hopper_throw_in_basket import MAX_EPISODE_STEPS_HOPPER
|
||||
from .mujoco.reacher.reacher import ReacherEnv, MAX_EPISODE_STEPS_REACHER
|
||||
from .mujoco.walker_2d_jump.walker_2d_jump import MAX_EPISODE_STEPS_WALKERJUMP
|
||||
|
||||
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
|
||||
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}
|
||||
|
||||
DEFAULT_BB_DICT_ProMP = {
|
||||
"name": 'EnvName',
|
||||
@ -205,7 +205,6 @@ register(
|
||||
max_episode_steps=FIXED_RELEASE_STEP,
|
||||
)
|
||||
|
||||
|
||||
# movement Primitive Environments
|
||||
|
||||
## Simple Reacher
|
||||
|
@ -1,6 +1,6 @@
|
||||
# MetaWorld Wrappers
|
||||
|
||||
These are the Environment Wrappers for selected [Metaworld](https://meta-world.github.io/) environments in order to use our Motion Primitive gym interface with them.
|
||||
These are the Environment Wrappers for selected [Metaworld](https://meta-world.github.io/) environments in order to use our Movement Primitive gym interface with them.
|
||||
All Metaworld environments have a 39 dimensional observation space with the same structure. The tasks differ only in the objective and the initial observations that are randomized.
|
||||
Unused observations are zeroed out. E.g. for `Button-Press-v2` the observation mask looks the following:
|
||||
```python
|
||||
|
@ -5,7 +5,7 @@ from gym import register
|
||||
from . import goal_object_change_mp_wrapper, goal_change_mp_wrapper, goal_endeffector_change_mp_wrapper, \
|
||||
object_change_mp_wrapper
|
||||
|
||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
|
||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}
|
||||
|
||||
# MetaWorld
|
||||
|
||||
@ -28,11 +28,31 @@ DEFAULT_BB_DICT_ProMP = {
|
||||
}
|
||||
}
|
||||
|
||||
DEFAULT_BB_DICT_ProDMP = {
|
||||
"name": 'EnvName',
|
||||
"wrappers": [],
|
||||
"trajectory_generator_kwargs": {
|
||||
'trajectory_generator_type': 'prodmp'
|
||||
},
|
||||
"phase_generator_kwargs": {
|
||||
'phase_generator_type': 'exp'
|
||||
},
|
||||
"controller_kwargs": {
|
||||
'controller_type': 'metaworld',
|
||||
},
|
||||
"basis_generator_kwargs": {
|
||||
'basis_generator_type': 'prodmp',
|
||||
'num_basis': 5
|
||||
}
|
||||
}
|
||||
|
||||
_goal_change_envs = ["assembly-v2", "pick-out-of-hole-v2", "plate-slide-v2", "plate-slide-back-v2",
|
||||
"plate-slide-side-v2", "plate-slide-back-side-v2"]
|
||||
for _task in _goal_change_envs:
|
||||
task_id_split = _task.split("-")
|
||||
name = "".join([s.capitalize() for s in task_id_split[:-1]])
|
||||
|
||||
# ProMP
|
||||
_env_id = f'{name}ProMP-{task_id_split[-1]}'
|
||||
kwargs_dict_goal_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
||||
kwargs_dict_goal_change_promp['wrappers'].append(goal_change_mp_wrapper.MPWrapper)
|
||||
@ -45,10 +65,25 @@ for _task in _goal_change_envs:
|
||||
)
|
||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
|
||||
# ProDMP
|
||||
_env_id = f'{name}ProDMP-{task_id_split[-1]}'
|
||||
kwargs_dict_goal_change_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
|
||||
kwargs_dict_goal_change_prodmp['wrappers'].append(goal_change_mp_wrapper.MPWrapper)
|
||||
kwargs_dict_goal_change_prodmp['name'] = f'metaworld:{_task}'
|
||||
|
||||
register(
|
||||
id=_env_id,
|
||||
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_goal_change_prodmp
|
||||
)
|
||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
|
||||
|
||||
_object_change_envs = ["bin-picking-v2", "hammer-v2", "sweep-into-v2"]
|
||||
for _task in _object_change_envs:
|
||||
task_id_split = _task.split("-")
|
||||
name = "".join([s.capitalize() for s in task_id_split[:-1]])
|
||||
|
||||
# ProMP
|
||||
_env_id = f'{name}ProMP-{task_id_split[-1]}'
|
||||
kwargs_dict_object_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
||||
kwargs_dict_object_change_promp['wrappers'].append(object_change_mp_wrapper.MPWrapper)
|
||||
@ -60,6 +95,18 @@ for _task in _object_change_envs:
|
||||
)
|
||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
|
||||
# ProDMP
|
||||
_env_id = f'{name}ProDMP-{task_id_split[-1]}'
|
||||
kwargs_dict_object_change_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
|
||||
kwargs_dict_object_change_prodmp['wrappers'].append(object_change_mp_wrapper.MPWrapper)
|
||||
kwargs_dict_object_change_prodmp['name'] = f'metaworld:{_task}'
|
||||
register(
|
||||
id=_env_id,
|
||||
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_object_change_prodmp
|
||||
)
|
||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
|
||||
|
||||
_goal_and_object_change_envs = ["box-close-v2", "button-press-v2", "button-press-wall-v2", "button-press-topdown-v2",
|
||||
"button-press-topdown-wall-v2", "coffee-button-v2", "coffee-pull-v2",
|
||||
"coffee-push-v2", "dial-turn-v2", "disassemble-v2", "door-close-v2",
|
||||
@ -74,6 +121,8 @@ _goal_and_object_change_envs = ["box-close-v2", "button-press-v2", "button-press
|
||||
for _task in _goal_and_object_change_envs:
|
||||
task_id_split = _task.split("-")
|
||||
name = "".join([s.capitalize() for s in task_id_split[:-1]])
|
||||
|
||||
# ProMP
|
||||
_env_id = f'{name}ProMP-{task_id_split[-1]}'
|
||||
kwargs_dict_goal_and_object_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
||||
kwargs_dict_goal_and_object_change_promp['wrappers'].append(goal_object_change_mp_wrapper.MPWrapper)
|
||||
@ -86,10 +135,26 @@ for _task in _goal_and_object_change_envs:
|
||||
)
|
||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
|
||||
# ProDMP
|
||||
_env_id = f'{name}ProDMP-{task_id_split[-1]}'
|
||||
kwargs_dict_goal_and_object_change_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
|
||||
kwargs_dict_goal_and_object_change_prodmp['wrappers'].append(goal_object_change_mp_wrapper.MPWrapper)
|
||||
kwargs_dict_goal_and_object_change_prodmp['name'] = f'metaworld:{_task}'
|
||||
|
||||
register(
|
||||
id=_env_id,
|
||||
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_goal_and_object_change_prodmp
|
||||
)
|
||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
|
||||
|
||||
|
||||
_goal_and_endeffector_change_envs = ["basketball-v2"]
|
||||
for _task in _goal_and_endeffector_change_envs:
|
||||
task_id_split = _task.split("-")
|
||||
name = "".join([s.capitalize() for s in task_id_split[:-1]])
|
||||
|
||||
# ProMP
|
||||
_env_id = f'{name}ProMP-{task_id_split[-1]}'
|
||||
kwargs_dict_goal_and_endeffector_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
||||
kwargs_dict_goal_and_endeffector_change_promp['wrappers'].append(goal_endeffector_change_mp_wrapper.MPWrapper)
|
||||
@ -101,3 +166,16 @@ for _task in _goal_and_endeffector_change_envs:
|
||||
kwargs=kwargs_dict_goal_and_endeffector_change_promp
|
||||
)
|
||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
|
||||
# ProDMP
|
||||
_env_id = f'{name}ProDMP-{task_id_split[-1]}'
|
||||
kwargs_dict_goal_and_endeffector_change_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
|
||||
kwargs_dict_goal_and_endeffector_change_prodmp['wrappers'].append(goal_endeffector_change_mp_wrapper.MPWrapper)
|
||||
kwargs_dict_goal_and_endeffector_change_prodmp['name'] = f'metaworld:{_task}'
|
||||
|
||||
register(
|
||||
id=_env_id,
|
||||
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||
kwargs=kwargs_dict_goal_and_endeffector_change_prodmp
|
||||
)
|
||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
|
||||
|
@ -5,7 +5,7 @@ from gym import register
|
||||
from . import mujoco
|
||||
from .deprecated_needs_gym_robotics import robotics
|
||||
|
||||
ALL_GYM_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
|
||||
ALL_GYM_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}
|
||||
|
||||
DEFAULT_BB_DICT_ProMP = {
|
||||
"name": 'EnvName',
|
||||
|
Loading…
Reference in New Issue
Block a user