Added more tests
This commit is contained in:
parent
915ffbe928
commit
00f622e913
@ -1,4 +1,5 @@
|
|||||||
from mp_pytorch.basis_gn import NormalizedRBFBasisGenerator, ZeroPaddingNormalizedRBFBasisGenerator
|
from mp_pytorch.basis_gn import NormalizedRBFBasisGenerator, ZeroPaddingNormalizedRBFBasisGenerator, \
|
||||||
|
ProDMPBasisGenerator
|
||||||
from mp_pytorch.phase_gn import PhaseGenerator
|
from mp_pytorch.phase_gn import PhaseGenerator
|
||||||
|
|
||||||
ALL_TYPES = ["rbf", "zero_rbf", "rhythmic"]
|
ALL_TYPES = ["rbf", "zero_rbf", "rhythmic"]
|
||||||
@ -10,6 +11,10 @@ def get_basis_generator(basis_generator_type: str, phase_generator: PhaseGenerat
|
|||||||
return NormalizedRBFBasisGenerator(phase_generator, **kwargs)
|
return NormalizedRBFBasisGenerator(phase_generator, **kwargs)
|
||||||
elif basis_generator_type == "zero_rbf":
|
elif basis_generator_type == "zero_rbf":
|
||||||
return ZeroPaddingNormalizedRBFBasisGenerator(phase_generator, **kwargs)
|
return ZeroPaddingNormalizedRBFBasisGenerator(phase_generator, **kwargs)
|
||||||
|
elif basis_generator_type == "prodmp":
|
||||||
|
from mp_pytorch.phase_gn import ExpDecayPhaseGenerator
|
||||||
|
assert isinstance(phase_generator, ExpDecayPhaseGenerator)
|
||||||
|
return ProDMPBasisGenerator(phase_generator, **kwargs)
|
||||||
elif basis_generator_type == "rhythmic":
|
elif basis_generator_type == "rhythmic":
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
# return RhythmicBasisGenerator(phase_generator, **kwargs)
|
# return RhythmicBasisGenerator(phase_generator, **kwargs)
|
||||||
|
@ -2,7 +2,7 @@ from copy import deepcopy
|
|||||||
|
|
||||||
from . import manipulation, suite
|
from . import manipulation, suite
|
||||||
|
|
||||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
|
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}
|
||||||
|
|
||||||
from gym.envs.registration import register
|
from gym.envs.registration import register
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@ from .mujoco.hopper_throw.hopper_throw_in_basket import MAX_EPISODE_STEPS_HOPPER
|
|||||||
from .mujoco.reacher.reacher import ReacherEnv, MAX_EPISODE_STEPS_REACHER
|
from .mujoco.reacher.reacher import ReacherEnv, MAX_EPISODE_STEPS_REACHER
|
||||||
from .mujoco.walker_2d_jump.walker_2d_jump import MAX_EPISODE_STEPS_WALKERJUMP
|
from .mujoco.walker_2d_jump.walker_2d_jump import MAX_EPISODE_STEPS_WALKERJUMP
|
||||||
|
|
||||||
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
|
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}
|
||||||
|
|
||||||
DEFAULT_BB_DICT_ProMP = {
|
DEFAULT_BB_DICT_ProMP = {
|
||||||
"name": 'EnvName',
|
"name": 'EnvName',
|
||||||
@ -205,7 +205,6 @@ register(
|
|||||||
max_episode_steps=FIXED_RELEASE_STEP,
|
max_episode_steps=FIXED_RELEASE_STEP,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# movement Primitive Environments
|
# movement Primitive Environments
|
||||||
|
|
||||||
## Simple Reacher
|
## Simple Reacher
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# MetaWorld Wrappers
|
# MetaWorld Wrappers
|
||||||
|
|
||||||
These are the Environment Wrappers for selected [Metaworld](https://meta-world.github.io/) environments in order to use our Motion Primitive gym interface with them.
|
These are the Environment Wrappers for selected [Metaworld](https://meta-world.github.io/) environments in order to use our Movement Primitive gym interface with them.
|
||||||
All Metaworld environments have a 39-dimensional observation space with the same structure. The tasks differ only in their objective and in which initial observations are randomized.
|
All Metaworld environments have a 39-dimensional observation space with the same structure. The tasks differ only in their objective and in which initial observations are randomized.
|
||||||
Unused observations are zeroed out. For example, the observation mask for `Button-Press-v2` looks as follows:
|
Unused observations are zeroed out. For example, the observation mask for `Button-Press-v2` looks as follows:
|
||||||
```python
|
```python
|
||||||
|
@ -5,7 +5,7 @@ from gym import register
|
|||||||
from . import goal_object_change_mp_wrapper, goal_change_mp_wrapper, goal_endeffector_change_mp_wrapper, \
|
from . import goal_object_change_mp_wrapper, goal_change_mp_wrapper, goal_endeffector_change_mp_wrapper, \
|
||||||
object_change_mp_wrapper
|
object_change_mp_wrapper
|
||||||
|
|
||||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
|
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}
|
||||||
|
|
||||||
# MetaWorld
|
# MetaWorld
|
||||||
|
|
||||||
@ -28,11 +28,31 @@ DEFAULT_BB_DICT_ProMP = {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DEFAULT_BB_DICT_ProDMP = {
|
||||||
|
"name": 'EnvName',
|
||||||
|
"wrappers": [],
|
||||||
|
"trajectory_generator_kwargs": {
|
||||||
|
'trajectory_generator_type': 'prodmp'
|
||||||
|
},
|
||||||
|
"phase_generator_kwargs": {
|
||||||
|
'phase_generator_type': 'exp'
|
||||||
|
},
|
||||||
|
"controller_kwargs": {
|
||||||
|
'controller_type': 'metaworld',
|
||||||
|
},
|
||||||
|
"basis_generator_kwargs": {
|
||||||
|
'basis_generator_type': 'prodmp',
|
||||||
|
'num_basis': 5
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
_goal_change_envs = ["assembly-v2", "pick-out-of-hole-v2", "plate-slide-v2", "plate-slide-back-v2",
|
_goal_change_envs = ["assembly-v2", "pick-out-of-hole-v2", "plate-slide-v2", "plate-slide-back-v2",
|
||||||
"plate-slide-side-v2", "plate-slide-back-side-v2"]
|
"plate-slide-side-v2", "plate-slide-back-side-v2"]
|
||||||
for _task in _goal_change_envs:
|
for _task in _goal_change_envs:
|
||||||
task_id_split = _task.split("-")
|
task_id_split = _task.split("-")
|
||||||
name = "".join([s.capitalize() for s in task_id_split[:-1]])
|
name = "".join([s.capitalize() for s in task_id_split[:-1]])
|
||||||
|
|
||||||
|
# ProMP
|
||||||
_env_id = f'{name}ProMP-{task_id_split[-1]}'
|
_env_id = f'{name}ProMP-{task_id_split[-1]}'
|
||||||
kwargs_dict_goal_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
kwargs_dict_goal_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
||||||
kwargs_dict_goal_change_promp['wrappers'].append(goal_change_mp_wrapper.MPWrapper)
|
kwargs_dict_goal_change_promp['wrappers'].append(goal_change_mp_wrapper.MPWrapper)
|
||||||
@ -45,10 +65,25 @@ for _task in _goal_change_envs:
|
|||||||
)
|
)
|
||||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||||
|
|
||||||
|
# ProDMP
|
||||||
|
_env_id = f'{name}ProDMP-{task_id_split[-1]}'
|
||||||
|
kwargs_dict_goal_change_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
|
||||||
|
kwargs_dict_goal_change_prodmp['wrappers'].append(goal_change_mp_wrapper.MPWrapper)
|
||||||
|
kwargs_dict_goal_change_prodmp['name'] = f'metaworld:{_task}'
|
||||||
|
|
||||||
|
register(
|
||||||
|
id=_env_id,
|
||||||
|
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||||
|
kwargs=kwargs_dict_goal_change_prodmp
|
||||||
|
)
|
||||||
|
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
|
||||||
|
|
||||||
_object_change_envs = ["bin-picking-v2", "hammer-v2", "sweep-into-v2"]
|
_object_change_envs = ["bin-picking-v2", "hammer-v2", "sweep-into-v2"]
|
||||||
for _task in _object_change_envs:
|
for _task in _object_change_envs:
|
||||||
task_id_split = _task.split("-")
|
task_id_split = _task.split("-")
|
||||||
name = "".join([s.capitalize() for s in task_id_split[:-1]])
|
name = "".join([s.capitalize() for s in task_id_split[:-1]])
|
||||||
|
|
||||||
|
# ProMP
|
||||||
_env_id = f'{name}ProMP-{task_id_split[-1]}'
|
_env_id = f'{name}ProMP-{task_id_split[-1]}'
|
||||||
kwargs_dict_object_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
kwargs_dict_object_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
||||||
kwargs_dict_object_change_promp['wrappers'].append(object_change_mp_wrapper.MPWrapper)
|
kwargs_dict_object_change_promp['wrappers'].append(object_change_mp_wrapper.MPWrapper)
|
||||||
@ -60,6 +95,18 @@ for _task in _object_change_envs:
|
|||||||
)
|
)
|
||||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||||
|
|
||||||
|
# ProDMP
|
||||||
|
_env_id = f'{name}ProDMP-{task_id_split[-1]}'
|
||||||
|
kwargs_dict_object_change_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
|
||||||
|
kwargs_dict_object_change_prodmp['wrappers'].append(object_change_mp_wrapper.MPWrapper)
|
||||||
|
kwargs_dict_object_change_prodmp['name'] = f'metaworld:{_task}'
|
||||||
|
register(
|
||||||
|
id=_env_id,
|
||||||
|
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||||
|
kwargs=kwargs_dict_object_change_prodmp
|
||||||
|
)
|
||||||
|
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
|
||||||
|
|
||||||
_goal_and_object_change_envs = ["box-close-v2", "button-press-v2", "button-press-wall-v2", "button-press-topdown-v2",
|
_goal_and_object_change_envs = ["box-close-v2", "button-press-v2", "button-press-wall-v2", "button-press-topdown-v2",
|
||||||
"button-press-topdown-wall-v2", "coffee-button-v2", "coffee-pull-v2",
|
"button-press-topdown-wall-v2", "coffee-button-v2", "coffee-pull-v2",
|
||||||
"coffee-push-v2", "dial-turn-v2", "disassemble-v2", "door-close-v2",
|
"coffee-push-v2", "dial-turn-v2", "disassemble-v2", "door-close-v2",
|
||||||
@ -74,6 +121,8 @@ _goal_and_object_change_envs = ["box-close-v2", "button-press-v2", "button-press
|
|||||||
for _task in _goal_and_object_change_envs:
|
for _task in _goal_and_object_change_envs:
|
||||||
task_id_split = _task.split("-")
|
task_id_split = _task.split("-")
|
||||||
name = "".join([s.capitalize() for s in task_id_split[:-1]])
|
name = "".join([s.capitalize() for s in task_id_split[:-1]])
|
||||||
|
|
||||||
|
# ProMP
|
||||||
_env_id = f'{name}ProMP-{task_id_split[-1]}'
|
_env_id = f'{name}ProMP-{task_id_split[-1]}'
|
||||||
kwargs_dict_goal_and_object_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
kwargs_dict_goal_and_object_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
||||||
kwargs_dict_goal_and_object_change_promp['wrappers'].append(goal_object_change_mp_wrapper.MPWrapper)
|
kwargs_dict_goal_and_object_change_promp['wrappers'].append(goal_object_change_mp_wrapper.MPWrapper)
|
||||||
@ -86,10 +135,26 @@ for _task in _goal_and_object_change_envs:
|
|||||||
)
|
)
|
||||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||||
|
|
||||||
|
# ProDMP
|
||||||
|
_env_id = f'{name}ProDMP-{task_id_split[-1]}'
|
||||||
|
kwargs_dict_goal_and_object_change_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
|
||||||
|
kwargs_dict_goal_and_object_change_prodmp['wrappers'].append(goal_object_change_mp_wrapper.MPWrapper)
|
||||||
|
kwargs_dict_goal_and_object_change_prodmp['name'] = f'metaworld:{_task}'
|
||||||
|
|
||||||
|
register(
|
||||||
|
id=_env_id,
|
||||||
|
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||||
|
kwargs=kwargs_dict_goal_and_object_change_prodmp
|
||||||
|
)
|
||||||
|
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
|
||||||
|
|
||||||
|
|
||||||
_goal_and_endeffector_change_envs = ["basketball-v2"]
|
_goal_and_endeffector_change_envs = ["basketball-v2"]
|
||||||
for _task in _goal_and_endeffector_change_envs:
|
for _task in _goal_and_endeffector_change_envs:
|
||||||
task_id_split = _task.split("-")
|
task_id_split = _task.split("-")
|
||||||
name = "".join([s.capitalize() for s in task_id_split[:-1]])
|
name = "".join([s.capitalize() for s in task_id_split[:-1]])
|
||||||
|
|
||||||
|
# ProMP
|
||||||
_env_id = f'{name}ProMP-{task_id_split[-1]}'
|
_env_id = f'{name}ProMP-{task_id_split[-1]}'
|
||||||
kwargs_dict_goal_and_endeffector_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
kwargs_dict_goal_and_endeffector_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
||||||
kwargs_dict_goal_and_endeffector_change_promp['wrappers'].append(goal_endeffector_change_mp_wrapper.MPWrapper)
|
kwargs_dict_goal_and_endeffector_change_promp['wrappers'].append(goal_endeffector_change_mp_wrapper.MPWrapper)
|
||||||
@ -101,3 +166,16 @@ for _task in _goal_and_endeffector_change_envs:
|
|||||||
kwargs=kwargs_dict_goal_and_endeffector_change_promp
|
kwargs=kwargs_dict_goal_and_endeffector_change_promp
|
||||||
)
|
)
|
||||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||||
|
|
||||||
|
# ProDMP
|
||||||
|
_env_id = f'{name}ProDMP-{task_id_split[-1]}'
|
||||||
|
kwargs_dict_goal_and_endeffector_change_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
|
||||||
|
kwargs_dict_goal_and_endeffector_change_prodmp['wrappers'].append(goal_endeffector_change_mp_wrapper.MPWrapper)
|
||||||
|
kwargs_dict_goal_and_endeffector_change_prodmp['name'] = f'metaworld:{_task}'
|
||||||
|
|
||||||
|
register(
|
||||||
|
id=_env_id,
|
||||||
|
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||||
|
kwargs=kwargs_dict_goal_and_endeffector_change_prodmp
|
||||||
|
)
|
||||||
|
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
|
||||||
|
@ -5,7 +5,7 @@ from gym import register
|
|||||||
from . import mujoco
|
from . import mujoco
|
||||||
from .deprecated_needs_gym_robotics import robotics
|
from .deprecated_needs_gym_robotics import robotics
|
||||||
|
|
||||||
ALL_GYM_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
|
ALL_GYM_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}
|
||||||
|
|
||||||
DEFAULT_BB_DICT_ProMP = {
|
DEFAULT_BB_DICT_ProMP = {
|
||||||
"name": 'EnvName',
|
"name": 'EnvName',
|
||||||
|
Loading…
Reference in New Issue
Block a user