diff --git a/alr_envs/__init__.py b/alr_envs/__init__.py index e43e3b1..30fa7b8 100644 --- a/alr_envs/__init__.py +++ b/alr_envs/__init__.py @@ -1,5 +1,5 @@ from alr_envs import dmc, meta, open_ai -from alr_envs.utils.make_env_helpers import make, make_detpmp_env, make_dmp_env, make_rank +from alr_envs.utils.make_env_helpers import make, make_detpmp_env, make_dmp_env, make_promp_env, make_rank from alr_envs.utils import make_dmc # Convenience function for all MP environments diff --git a/alr_envs/alr/__init__.py b/alr_envs/alr/__init__.py index d9843c0..627625d 100644 --- a/alr_envs/alr/__init__.py +++ b/alr_envs/alr/__init__.py @@ -10,7 +10,7 @@ from .mujoco.ball_in_a_cup.biac_pd import ALRBallInACupPDEnv from .mujoco.reacher.alr_reacher import ALRReacherEnv from .mujoco.reacher.balancing import BalancingEnv -ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "DetPMP": []} +ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []} # Classic Control ## Simple Reacher @@ -335,6 +335,25 @@ for _v in _versions: ) ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id) + _env_id = f'HoleReacherProMP-{_v}' # ProMP variant of HoleReacher; built by make_promp_env_helper and listed under the "ProMP" key + register( + id=_env_id, + entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper', + kwargs={ + "name": f"alr_envs:HoleReacher-{_v}", + "wrappers": [classic_control.hole_reacher.MPWrapper], + "mp_kwargs": { + "num_dof": 5, + "num_basis": 5, + "duration": 2, + "policy_type": "velocity", + "weights_scale": 0.2, + "zero_start": True + } + } + ) + ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) + _env_id = f'HoleReacherDetPMP-{_v}' register( id=_env_id, diff --git a/alr_envs/dmc/__init__.py b/alr_envs/dmc/__init__.py index 17d1f7f..ca6469e 100644 --- a/alr_envs/dmc/__init__.py +++ b/alr_envs/dmc/__init__.py @@ -1,6 +1,6 @@ from . 
import manipulation, suite -ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "DetPMP": []} +ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []} from gym.envs.registration import register diff --git a/alr_envs/examples/examples_motion_primitives.py b/alr_envs/examples/examples_motion_primitives.py index 6decdb1..de365b7 100644 --- a/alr_envs/examples/examples_motion_primitives.py +++ b/alr_envs/examples/examples_motion_primitives.py @@ -147,10 +147,13 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True): if __name__ == '__main__': - render = False + render = True # NOTE(review): rendering is now on by default for every example call below; confirm this is intended # DMP example_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=render) + # ProMP (id registered in alr_envs/alr/__init__.py as HoleReacherProMP-{version}) + example_mp("alr_envs:HoleReacherProMP-v1", seed=10, iterations=1, render=render) + # DetProMP example_mp("alr_envs:HoleReacherDetPMP-v1", seed=10, iterations=1, render=render) diff --git a/alr_envs/meta/__init__.py b/alr_envs/meta/__init__.py index fa63c94..9db0689 100644 --- a/alr_envs/meta/__init__.py +++ b/alr_envs/meta/__init__.py @@ -3,7 +3,7 @@ from gym import register from . import goal_object_change_mp_wrapper, goal_change_mp_wrapper, goal_endeffector_change_mp_wrapper, \ object_change_mp_wrapper -ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "DetPMP": []} +ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []} # MetaWorld diff --git a/alr_envs/open_ai/__init__.py b/alr_envs/open_ai/__init__.py index 51dd712..63083ca 100644 --- a/alr_envs/open_ai/__init__.py +++ b/alr_envs/open_ai/__init__.py @@ -3,7 +3,7 @@ from gym.wrappers import FlattenObservation from . 
import classic_control, mujoco, robotics -ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "DetPMP": []} +ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []} # Short Continuous Mountain Car register( diff --git a/alr_envs/utils/make_env_helpers.py b/alr_envs/utils/make_env_helpers.py index fc73b05..19f54d6 100644 --- a/alr_envs/utils/make_env_helpers.py +++ b/alr_envs/utils/make_env_helpers.py @@ -7,6 +7,7 @@ from gym.envs.registration import EnvSpec from mp_env_api import MPEnvWrapper from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper +from mp_env_api.mp_wrappers.promp_wrapper import ProMPWrapper def make_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwargs): @@ -132,6 +133,26 @@ def make_dmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs return DmpWrapper(_env, **mp_kwargs) +def make_promp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs): + """ + Build the base env via _make_wrapped_env and wrap it in a ProMPWrapper; can also be used standalone for manually building a custom ProMP environment. + Args: + env_id: base_env_name (handed to _make_wrapped_env together with seed and **kwargs), + wrappers: list of wrappers (at least an MPEnvWrapper), + mp_kwargs: dict of ProMPWrapper keyword arguments, e.g. {num_dof: int, num_basis: int}; "duration" and "num_dof" are also passed to _verify_time_limit and _verify_dof (width: int is omitted by the HoleReacherProMP registration, so presumably optional; TODO confirm) + + Returns: ProMP wrapped gym env + + """ + _verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None)) + + _env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs) + + _verify_dof(_env, mp_kwargs.get("num_dof")) + + return ProMPWrapper(_env, **mp_kwargs) + + def make_detpmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs): """ This can also be used standalone for manually building a custom Det ProMP environment. 
@@ -140,7 +161,7 @@ def make_detpmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwa wrappers: list of wrappers (at least an MPEnvWrapper), mp_kwargs: dict of at least {num_dof: int, num_basis: int, width: int} - Returns: DMP wrapped gym env + Returns: Det ProMP wrapped gym env """ _verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None)) @@ -171,6 +192,26 @@ def make_dmp_env_helper(**kwargs): mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs) +def make_promp_env_helper(**kwargs): + """ + Helper function for registering ProMP gym environments. + This can also be used standalone for manually building a custom ProMP environment. + Args: + **kwargs: expects at least the following (an optional "seed" is also consumed; any further entries are forwarded to make_promp_env): + { + "name": base_env_name, + "wrappers": list of wrappers (at least an MPEnvWrapper), + "mp_kwargs": dict handed to make_promp_env as mp_kwargs, e.g. {num_dof: int, num_basis: int} (width: int is omitted by the HoleReacherProMP registration, so presumably optional; TODO confirm) + } + + Returns: ProMP wrapped gym env + + """ + seed = kwargs.pop("seed", None) + return make_promp_env(env_id=kwargs.pop("name"), wrappers=kwargs.pop("wrappers"), seed=seed, + mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs) + + def make_detpmp_env_helper(**kwargs): """ Helper function for registering ProMP gym environments.