added promp wrapped environments
This commit is contained in:
parent
ac969d490a
commit
86f894c1ff
@ -1,5 +1,5 @@
|
||||
from alr_envs import dmc, meta, open_ai
|
||||
from alr_envs.utils.make_env_helpers import make, make_detpmp_env, make_dmp_env, make_rank
|
||||
from alr_envs.utils.make_env_helpers import make, make_detpmp_env, make_dmp_env, make_promp_env, make_rank
|
||||
from alr_envs.utils import make_dmc
|
||||
|
||||
# Convenience function for all MP environments
|
||||
|
@ -9,7 +9,7 @@ from .mujoco.ball_in_a_cup.biac_pd import ALRBallInACupPDEnv
|
||||
from .mujoco.reacher.alr_reacher import ALRReacherEnv
|
||||
from .mujoco.reacher.balancing import BalancingEnv
|
||||
|
||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "DetPMP": []}
|
||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []}
|
||||
|
||||
# Classic Control
|
||||
## Simple Reacher
|
||||
@ -308,6 +308,25 @@ for _v in _versions:
|
||||
)
|
||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
|
||||
|
||||
_env_id = f'HoleReacherProMP-{_v}'
|
||||
register(
|
||||
id=_env_id,
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
|
||||
kwargs={
|
||||
"name": f"alr_envs:HoleReacher-{_v}",
|
||||
"wrappers": [classic_control.hole_reacher.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 5,
|
||||
"num_basis": 5,
|
||||
"duration": 2,
|
||||
"policy_type": "velocity",
|
||||
"weights_scale": 0.2,
|
||||
"zero_start": True
|
||||
}
|
||||
}
|
||||
)
|
||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
|
||||
_env_id = f'HoleReacherDetPMP-{_v}'
|
||||
register(
|
||||
id=_env_id,
|
||||
|
@ -1,6 +1,6 @@
|
||||
from . import manipulation, suite
|
||||
|
||||
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "DetPMP": []}
|
||||
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []}
|
||||
|
||||
from gym.envs.registration import register
|
||||
|
||||
|
@ -147,10 +147,13 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
render = False
|
||||
render = True
|
||||
# DMP
|
||||
example_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=render)
|
||||
|
||||
# ProMP
|
||||
example_mp("alr_envs:HoleReacherProMP-v1", seed=10, iterations=1, render=render)
|
||||
|
||||
# DetProMP
|
||||
example_mp("alr_envs:HoleReacherDetPMP-v1", seed=10, iterations=1, render=render)
|
||||
|
||||
|
@ -3,7 +3,7 @@ from gym import register
|
||||
from . import goal_object_change_mp_wrapper, goal_change_mp_wrapper, goal_endeffector_change_mp_wrapper, \
|
||||
object_change_mp_wrapper
|
||||
|
||||
ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "DetPMP": []}
|
||||
ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []}
|
||||
|
||||
# MetaWorld
|
||||
|
||||
|
@ -3,7 +3,7 @@ from gym.wrappers import FlattenObservation
|
||||
|
||||
from . import classic_control, mujoco, robotics
|
||||
|
||||
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "DetPMP": []}
|
||||
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []}
|
||||
|
||||
# Short Continuous Mountain Car
|
||||
register(
|
||||
|
@ -7,6 +7,7 @@ from gym.envs.registration import EnvSpec
|
||||
from mp_env_api import MPEnvWrapper
|
||||
from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper
|
||||
from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper
|
||||
from mp_env_api.mp_wrappers.promp_wrapper import ProMPWrapper
|
||||
|
||||
|
||||
def make_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwargs):
|
||||
@ -132,6 +133,26 @@ def make_dmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs
|
||||
return DmpWrapper(_env, **mp_kwargs)
|
||||
|
||||
|
||||
def make_promp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs):
|
||||
"""
|
||||
This can also be used standalone for manually building a custom ProMP environment.
|
||||
Args:
|
||||
env_id: base_env_name,
|
||||
wrappers: list of wrappers (at least an MPEnvWrapper),
|
||||
mp_kwargs: dict of at least {num_dof: int, num_basis: int, width: int}
|
||||
|
||||
Returns: ProMP wrapped gym env
|
||||
|
||||
"""
|
||||
_verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
|
||||
|
||||
_env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
|
||||
|
||||
_verify_dof(_env, mp_kwargs.get("num_dof"))
|
||||
|
||||
return ProMPWrapper(_env, **mp_kwargs)
|
||||
|
||||
|
||||
def make_detpmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs):
|
||||
"""
|
||||
This can also be used standalone for manually building a custom Det ProMP environment.
|
||||
@ -140,7 +161,7 @@ def make_detpmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwa
|
||||
wrappers: list of wrappers (at least an MPEnvWrapper),
|
||||
mp_kwargs: dict of at least {num_dof: int, num_basis: int, width: int}
|
||||
|
||||
Returns: DMP wrapped gym env
|
||||
Returns: Det ProMP wrapped gym env
|
||||
|
||||
"""
|
||||
_verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
|
||||
@ -171,6 +192,26 @@ def make_dmp_env_helper(**kwargs):
|
||||
mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs)
|
||||
|
||||
|
||||
def make_promp_env_helper(**kwargs):
|
||||
"""
|
||||
Helper function for registering ProMP gym environments.
|
||||
This can also be used standalone for manually building a custom ProMP environment.
|
||||
Args:
|
||||
**kwargs: expects at least the following:
|
||||
{
|
||||
"name": base_env_name,
|
||||
"wrappers": list of wrappers (at least an MPEnvWrapper),
|
||||
"mp_kwargs": dict of at least {num_dof: int, num_basis: int, width: int}
|
||||
}
|
||||
|
||||
Returns: ProMP wrapped gym env
|
||||
|
||||
"""
|
||||
seed = kwargs.pop("seed", None)
|
||||
return make_promp_env(env_id=kwargs.pop("name"), wrappers=kwargs.pop("wrappers"), seed=seed,
|
||||
mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs)
|
||||
|
||||
|
||||
def make_detpmp_env_helper(**kwargs):
|
||||
"""
|
||||
Helper function for registering ProMP gym environments.
|
||||
|
Loading…
Reference in New Issue
Block a user