Merge branch 'promp_wrapper' into reacher_env_cleanup

This commit is contained in:
Maximilian Huettenrauch 2021-11-15 10:55:00 +01:00
commit 786d481c88
7 changed files with 70 additions and 7 deletions

View File

@ -1,5 +1,5 @@
from alr_envs import dmc, meta, open_ai from alr_envs import dmc, meta, open_ai
from alr_envs.utils.make_env_helpers import make, make_detpmp_env, make_dmp_env, make_rank from alr_envs.utils.make_env_helpers import make, make_detpmp_env, make_dmp_env, make_promp_env, make_rank
from alr_envs.utils import make_dmc from alr_envs.utils import make_dmc
# Convenience function for all MP environments # Convenience function for all MP environments

View File

@ -9,7 +9,7 @@ from .mujoco.ball_in_a_cup.biac_pd import ALRBallInACupPDEnv
from .mujoco.reacher.alr_reacher import ALRReacherEnv from .mujoco.reacher.alr_reacher import ALRReacherEnv
from .mujoco.reacher.balancing import BalancingEnv from .mujoco.reacher.balancing import BalancingEnv
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "DetPMP": []} ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []}
# Classic Control # Classic Control
## Simple Reacher ## Simple Reacher
@ -308,6 +308,25 @@ for _v in _versions:
) )
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id) ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
_env_id = f'HoleReacherProMP-{_v}'
register(
id=_env_id,
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={
"name": f"alr_envs:HoleReacher-{_v}",
"wrappers": [classic_control.hole_reacher.MPWrapper],
"mp_kwargs": {
"num_dof": 5,
"num_basis": 5,
"duration": 2,
"policy_type": "velocity",
"weights_scale": 0.2,
"zero_start": True
}
}
)
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
_env_id = f'HoleReacherDetPMP-{_v}' _env_id = f'HoleReacherDetPMP-{_v}'
register( register(
id=_env_id, id=_env_id,

View File

@ -1,6 +1,6 @@
from . import manipulation, suite from . import manipulation, suite
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "DetPMP": []} ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []}
from gym.envs.registration import register from gym.envs.registration import register

View File

@ -147,10 +147,13 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
if __name__ == '__main__': if __name__ == '__main__':
render = False render = True
# DMP # DMP
example_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=render) example_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=render)
# ProMP
example_mp("alr_envs:HoleReacherProMP-v1", seed=10, iterations=1, render=render)
# DetProMP # DetProMP
example_mp("alr_envs:HoleReacherDetPMP-v1", seed=10, iterations=1, render=render) example_mp("alr_envs:HoleReacherDetPMP-v1", seed=10, iterations=1, render=render)

View File

@ -3,7 +3,7 @@ from gym import register
from . import goal_object_change_mp_wrapper, goal_change_mp_wrapper, goal_endeffector_change_mp_wrapper, \ from . import goal_object_change_mp_wrapper, goal_change_mp_wrapper, goal_endeffector_change_mp_wrapper, \
object_change_mp_wrapper object_change_mp_wrapper
ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "DetPMP": []} ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []}
# MetaWorld # MetaWorld

View File

@ -3,7 +3,7 @@ from gym.wrappers import FlattenObservation
from . import classic_control, mujoco, robotics from . import classic_control, mujoco, robotics
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "DetPMP": []} ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []}
# Short Continuous Mountain Car # Short Continuous Mountain Car
register( register(

View File

@ -7,6 +7,7 @@ from gym.envs.registration import EnvSpec
from mp_env_api import MPEnvWrapper from mp_env_api import MPEnvWrapper
from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper
from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper
from mp_env_api.mp_wrappers.promp_wrapper import ProMPWrapper
def make_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwargs): def make_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwargs):
@ -132,6 +133,26 @@ def make_dmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs
return DmpWrapper(_env, **mp_kwargs) return DmpWrapper(_env, **mp_kwargs)
def make_promp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs):
"""
This can also be used standalone for manually building a custom ProMP environment.
Args:
env_id: base_env_name,
wrappers: list of wrappers (at least an MPEnvWrapper),
mp_kwargs: dict of at least {num_dof: int, num_basis: int, width: int}
Returns: ProMP wrapped gym env
"""
_verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
_env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
_verify_dof(_env, mp_kwargs.get("num_dof"))
return ProMPWrapper(_env, **mp_kwargs)
def make_detpmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs): def make_detpmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs):
""" """
This can also be used standalone for manually building a custom Det ProMP environment. This can also be used standalone for manually building a custom Det ProMP environment.
@ -140,7 +161,7 @@ def make_detpmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwa
wrappers: list of wrappers (at least an MPEnvWrapper), wrappers: list of wrappers (at least an MPEnvWrapper),
mp_kwargs: dict of at least {num_dof: int, num_basis: int, width: int} mp_kwargs: dict of at least {num_dof: int, num_basis: int, width: int}
Returns: DMP wrapped gym env Returns: Det ProMP wrapped gym env
""" """
_verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None)) _verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
@ -171,6 +192,26 @@ def make_dmp_env_helper(**kwargs):
mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs) mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs)
def make_promp_env_helper(**kwargs):
"""
Helper function for registering ProMP gym environments.
This can also be used standalone for manually building a custom ProMP environment.
Args:
**kwargs: expects at least the following:
{
"name": base_env_name,
"wrappers": list of wrappers (at least an MPEnvWrapper),
"mp_kwargs": dict of at least {num_dof: int, num_basis: int, width: int}
}
Returns: ProMP wrapped gym env
"""
seed = kwargs.pop("seed", None)
return make_promp_env(env_id=kwargs.pop("name"), wrappers=kwargs.pop("wrappers"), seed=seed,
mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs)
def make_detpmp_env_helper(**kwargs): def make_detpmp_env_helper(**kwargs):
""" """
Helper function for registering ProMP gym environments. Helper function for registering ProMP gym environments.