Added ProMP-wrapped environments
This commit is contained in:
parent
ebca59b4bd
commit
a1a5da3f1e
@ -1,5 +1,5 @@
|
|||||||
from alr_envs import dmc, meta, open_ai
|
from alr_envs import dmc, meta, open_ai
|
||||||
from alr_envs.utils.make_env_helpers import make, make_detpmp_env, make_dmp_env, make_rank
|
from alr_envs.utils.make_env_helpers import make, make_detpmp_env, make_dmp_env, make_promp_env, make_rank
|
||||||
from alr_envs.utils import make_dmc
|
from alr_envs.utils import make_dmc
|
||||||
|
|
||||||
# Convenience function for all MP environments
|
# Convenience function for all MP environments
|
||||||
|
@ -10,7 +10,7 @@ from .mujoco.ball_in_a_cup.biac_pd import ALRBallInACupPDEnv
|
|||||||
from .mujoco.reacher.alr_reacher import ALRReacherEnv
|
from .mujoco.reacher.alr_reacher import ALRReacherEnv
|
||||||
from .mujoco.reacher.balancing import BalancingEnv
|
from .mujoco.reacher.balancing import BalancingEnv
|
||||||
|
|
||||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "DetPMP": []}
|
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []}
|
||||||
|
|
||||||
# Classic Control
|
# Classic Control
|
||||||
## Simple Reacher
|
## Simple Reacher
|
||||||
@ -335,6 +335,25 @@ for _v in _versions:
|
|||||||
)
|
)
|
||||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
|
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
|
||||||
|
|
||||||
|
_env_id = f'HoleReacherProMP-{_v}'
|
||||||
|
register(
|
||||||
|
id=_env_id,
|
||||||
|
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
|
||||||
|
kwargs={
|
||||||
|
"name": f"alr_envs:HoleReacher-{_v}",
|
||||||
|
"wrappers": [classic_control.hole_reacher.MPWrapper],
|
||||||
|
"mp_kwargs": {
|
||||||
|
"num_dof": 5,
|
||||||
|
"num_basis": 5,
|
||||||
|
"duration": 2,
|
||||||
|
"policy_type": "velocity",
|
||||||
|
"weights_scale": 0.2,
|
||||||
|
"zero_start": True
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||||
|
|
||||||
_env_id = f'HoleReacherDetPMP-{_v}'
|
_env_id = f'HoleReacherDetPMP-{_v}'
|
||||||
register(
|
register(
|
||||||
id=_env_id,
|
id=_env_id,
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from . import manipulation, suite
|
from . import manipulation, suite
|
||||||
|
|
||||||
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "DetPMP": []}
|
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []}
|
||||||
|
|
||||||
from gym.envs.registration import register
|
from gym.envs.registration import register
|
||||||
|
|
||||||
|
@ -147,10 +147,13 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
render = False
|
render = True
|
||||||
# DMP
|
# DMP
|
||||||
example_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=render)
|
example_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=render)
|
||||||
|
|
||||||
|
# ProMP
|
||||||
|
example_mp("alr_envs:HoleReacherProMP-v1", seed=10, iterations=1, render=render)
|
||||||
|
|
||||||
# DetProMP
|
# DetProMP
|
||||||
example_mp("alr_envs:HoleReacherDetPMP-v1", seed=10, iterations=1, render=render)
|
example_mp("alr_envs:HoleReacherDetPMP-v1", seed=10, iterations=1, render=render)
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ from gym import register
|
|||||||
from . import goal_object_change_mp_wrapper, goal_change_mp_wrapper, goal_endeffector_change_mp_wrapper, \
|
from . import goal_object_change_mp_wrapper, goal_change_mp_wrapper, goal_endeffector_change_mp_wrapper, \
|
||||||
object_change_mp_wrapper
|
object_change_mp_wrapper
|
||||||
|
|
||||||
ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "DetPMP": []}
|
ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []}
|
||||||
|
|
||||||
# MetaWorld
|
# MetaWorld
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ from gym.wrappers import FlattenObservation
|
|||||||
|
|
||||||
from . import classic_control, mujoco, robotics
|
from . import classic_control, mujoco, robotics
|
||||||
|
|
||||||
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "DetPMP": []}
|
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []}
|
||||||
|
|
||||||
# Short Continuous Mountain Car
|
# Short Continuous Mountain Car
|
||||||
register(
|
register(
|
||||||
|
@ -7,6 +7,7 @@ from gym.envs.registration import EnvSpec
|
|||||||
from mp_env_api import MPEnvWrapper
|
from mp_env_api import MPEnvWrapper
|
||||||
from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper
|
from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper
|
||||||
from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper
|
from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper
|
||||||
|
from mp_env_api.mp_wrappers.promp_wrapper import ProMPWrapper
|
||||||
|
|
||||||
|
|
||||||
def make_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwargs):
|
def make_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwargs):
|
||||||
@ -132,6 +133,26 @@ def make_dmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs
|
|||||||
return DmpWrapper(_env, **mp_kwargs)
|
return DmpWrapper(_env, **mp_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def make_promp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs):
    """
    Build a ProMP wrapped environment on top of a base environment.

    Besides being the target of the registration helper, this can be called directly to assemble a
    custom ProMP environment by hand.

    Args:
        env_id: base_env_name, forwarded to `_make_wrapped_env`
        wrappers: list of wrappers (at least an MPEnvWrapper), forwarded to `_make_wrapped_env`
        seed: seed forwarded to `_make_wrapped_env`
        mp_kwargs: keyword arguments unpacked into `ProMPWrapper`; its "duration" and "num_dof" entries
            (if present) are additionally handed to the verification helpers below
        **kwargs: forwarded to `_make_wrapped_env`; a "time_limit" entry (if present) is additionally
            handed to `_verify_time_limit`

    Returns: ProMP wrapped gym env

    """
    # Check the MP duration against the requested time limit before any environment is built.
    mp_duration = mp_kwargs.get("duration", None)
    requested_time_limit = kwargs.get("time_limit", None)
    _verify_time_limit(mp_duration, requested_time_limit)

    wrapped_env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)

    # The DoF check needs the already wrapped environment, so it runs after construction.
    expected_dof = mp_kwargs.get("num_dof")
    _verify_dof(wrapped_env, expected_dof)

    return ProMPWrapper(wrapped_env, **mp_kwargs)
|
||||||
|
|
||||||
|
|
||||||
def make_detpmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs):
|
def make_detpmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs):
|
||||||
"""
|
"""
|
||||||
This can also be used standalone for manually building a custom Det ProMP environment.
|
This can also be used standalone for manually building a custom Det ProMP environment.
|
||||||
@ -140,7 +161,7 @@ def make_detpmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwa
|
|||||||
wrappers: list of wrappers (at least an MPEnvWrapper),
|
wrappers: list of wrappers (at least an MPEnvWrapper),
|
||||||
mp_kwargs: dict of at least {num_dof: int, num_basis: int, width: int}
|
mp_kwargs: dict of at least {num_dof: int, num_basis: int, width: int}
|
||||||
|
|
||||||
Returns: DMP wrapped gym env
|
Returns: Det ProMP wrapped gym env
|
||||||
|
|
||||||
"""
|
"""
|
||||||
_verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
|
_verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
|
||||||
@ -171,6 +192,26 @@ def make_dmp_env_helper(**kwargs):
|
|||||||
mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs)
|
mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def make_promp_env_helper(**kwargs):
    """
    Helper function for registering ProMP gym environments.

    Unpacks the registration kwargs and delegates to `make_promp_env`; it can also be called directly
    to build a custom ProMP environment.

    Args:
        **kwargs: expects at least the following:
        {
        "name": base_env_name,
        "wrappers": list of wrappers (at least an MPEnvWrapper),
        "mp_kwargs": dict passed on as `mp_kwargs` of `make_promp_env`
        }
        An optional "seed" entry is passed on as well; every remaining entry is forwarded unchanged.

    Returns: ProMP wrapped gym env

    """
    # A missing "seed" is forwarded explicitly as None rather than falling back to make_promp_env's default.
    seed = kwargs.pop("seed", None)

    # Required entries: a missing one raises KeyError, checked in this order.
    base_env_name = kwargs.pop("name")
    env_wrappers = kwargs.pop("wrappers")
    promp_kwargs = kwargs.pop("mp_kwargs")

    # Whatever is left in kwargs goes through untouched.
    return make_promp_env(env_id=base_env_name, wrappers=env_wrappers, seed=seed, mp_kwargs=promp_kwargs,
                          **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def make_detpmp_env_helper(**kwargs):
|
def make_detpmp_env_helper(**kwargs):
|
||||||
"""
|
"""
|
||||||
Helper function for registering ProMP gym environments.
|
Helper function for registering ProMP gym environments.
|
||||||
|
Loading…
Reference in New Issue
Block a user