added promp wrapped environments
This commit is contained in:
		
							parent
							
								
									ebca59b4bd
								
							
						
					
					
						commit
						a1a5da3f1e
					
				@ -1,5 +1,5 @@
 | 
				
			|||||||
from alr_envs import dmc, meta, open_ai
 | 
					from alr_envs import dmc, meta, open_ai
 | 
				
			||||||
from alr_envs.utils.make_env_helpers import make, make_detpmp_env, make_dmp_env, make_rank
 | 
					from alr_envs.utils.make_env_helpers import make, make_detpmp_env, make_dmp_env, make_promp_env, make_rank
 | 
				
			||||||
from alr_envs.utils import make_dmc
 | 
					from alr_envs.utils import make_dmc
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Convenience function for all MP environments
 | 
					# Convenience function for all MP environments
 | 
				
			||||||
 | 
				
			|||||||
@ -10,7 +10,7 @@ from .mujoco.ball_in_a_cup.biac_pd import ALRBallInACupPDEnv
 | 
				
			|||||||
from .mujoco.reacher.alr_reacher import ALRReacherEnv
 | 
					from .mujoco.reacher.alr_reacher import ALRReacherEnv
 | 
				
			||||||
from .mujoco.reacher.balancing import BalancingEnv
 | 
					from .mujoco.reacher.balancing import BalancingEnv
 | 
				
			||||||
 | 
					
 | 
				
			||||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "DetPMP": []}
 | 
					ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Classic Control
 | 
					# Classic Control
 | 
				
			||||||
## Simple Reacher
 | 
					## Simple Reacher
 | 
				
			||||||
@ -335,6 +335,25 @@ for _v in _versions:
 | 
				
			|||||||
    )
 | 
					    )
 | 
				
			||||||
    ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
 | 
					    ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    _env_id = f'HoleReacherProMP-{_v}'
 | 
				
			||||||
 | 
					    register(
 | 
				
			||||||
 | 
					        id=_env_id,
 | 
				
			||||||
 | 
					        entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
 | 
				
			||||||
 | 
					        kwargs={
 | 
				
			||||||
 | 
					            "name": f"alr_envs:HoleReacher-{_v}",
 | 
				
			||||||
 | 
					            "wrappers": [classic_control.hole_reacher.MPWrapper],
 | 
				
			||||||
 | 
					            "mp_kwargs": {
 | 
				
			||||||
 | 
					                "num_dof": 5,
 | 
				
			||||||
 | 
					                "num_basis": 5,
 | 
				
			||||||
 | 
					                "duration": 2,
 | 
				
			||||||
 | 
					                "policy_type": "velocity",
 | 
				
			||||||
 | 
					                "weights_scale": 0.2,
 | 
				
			||||||
 | 
					                "zero_start": True
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    _env_id = f'HoleReacherDetPMP-{_v}'
 | 
					    _env_id = f'HoleReacherDetPMP-{_v}'
 | 
				
			||||||
    register(
 | 
					    register(
 | 
				
			||||||
        id=_env_id,
 | 
					        id=_env_id,
 | 
				
			||||||
 | 
				
			|||||||
@ -1,6 +1,6 @@
 | 
				
			|||||||
from . import manipulation, suite
 | 
					from . import manipulation, suite
 | 
				
			||||||
 | 
					
 | 
				
			||||||
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "DetPMP": []}
 | 
					ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from gym.envs.registration import register
 | 
					from gym.envs.registration import register
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -147,10 +147,13 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == '__main__':
 | 
					if __name__ == '__main__':
 | 
				
			||||||
    render = False
 | 
					    render = True
 | 
				
			||||||
    # DMP
 | 
					    # DMP
 | 
				
			||||||
    example_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=render)
 | 
					    example_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=render)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # ProMP
 | 
				
			||||||
 | 
					    example_mp("alr_envs:HoleReacherProMP-v1", seed=10, iterations=1, render=render)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # DetProMP
 | 
					    # DetProMP
 | 
				
			||||||
    example_mp("alr_envs:HoleReacherDetPMP-v1", seed=10, iterations=1, render=render)
 | 
					    example_mp("alr_envs:HoleReacherDetPMP-v1", seed=10, iterations=1, render=render)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -3,7 +3,7 @@ from gym import register
 | 
				
			|||||||
from . import goal_object_change_mp_wrapper, goal_change_mp_wrapper, goal_endeffector_change_mp_wrapper, \
 | 
					from . import goal_object_change_mp_wrapper, goal_change_mp_wrapper, goal_endeffector_change_mp_wrapper, \
 | 
				
			||||||
    object_change_mp_wrapper
 | 
					    object_change_mp_wrapper
 | 
				
			||||||
 | 
					
 | 
				
			||||||
ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "DetPMP": []}
 | 
					ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# MetaWorld
 | 
					# MetaWorld
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -3,7 +3,7 @@ from gym.wrappers import FlattenObservation
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from . import classic_control, mujoco, robotics
 | 
					from . import classic_control, mujoco, robotics
 | 
				
			||||||
 | 
					
 | 
				
			||||||
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "DetPMP": []}
 | 
					ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Short Continuous Mountain Car
 | 
					# Short Continuous Mountain Car
 | 
				
			||||||
register(
 | 
					register(
 | 
				
			||||||
 | 
				
			|||||||
@ -7,6 +7,7 @@ from gym.envs.registration import EnvSpec
 | 
				
			|||||||
from mp_env_api import MPEnvWrapper
 | 
					from mp_env_api import MPEnvWrapper
 | 
				
			||||||
from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper
 | 
					from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper
 | 
				
			||||||
from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper
 | 
					from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper
 | 
				
			||||||
 | 
					from mp_env_api.mp_wrappers.promp_wrapper import ProMPWrapper
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def make_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwargs):
 | 
					def make_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwargs):
 | 
				
			||||||
@ -132,6 +133,26 @@ def make_dmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs
 | 
				
			|||||||
    return DmpWrapper(_env, **mp_kwargs)
 | 
					    return DmpWrapper(_env, **mp_kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def make_promp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    This can also be used standalone for manually building a custom ProMP environment.
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					        env_id: base_env_name,
 | 
				
			||||||
 | 
					        wrappers: list of wrappers (at least an MPEnvWrapper),
 | 
				
			||||||
 | 
					        mp_kwargs: dict of at least {num_dof: int, num_basis: int, width: int}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns: ProMP wrapped gym env
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    _verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    _env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    _verify_dof(_env, mp_kwargs.get("num_dof"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return ProMPWrapper(_env, **mp_kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def make_detpmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs):
 | 
					def make_detpmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    This can also be used standalone for manually building a custom Det ProMP environment.
 | 
					    This can also be used standalone for manually building a custom Det ProMP environment.
 | 
				
			||||||
@ -140,7 +161,7 @@ def make_detpmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwa
 | 
				
			|||||||
        wrappers: list of wrappers (at least an MPEnvWrapper),
 | 
					        wrappers: list of wrappers (at least an MPEnvWrapper),
 | 
				
			||||||
        mp_kwargs: dict of at least {num_dof: int, num_basis: int, width: int}
 | 
					        mp_kwargs: dict of at least {num_dof: int, num_basis: int, width: int}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Returns: DMP wrapped gym env
 | 
					    Returns: Det ProMP wrapped gym env
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    _verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
 | 
					    _verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
 | 
				
			||||||
@ -171,6 +192,26 @@ def make_dmp_env_helper(**kwargs):
 | 
				
			|||||||
                        mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs)
 | 
					                        mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def make_promp_env_helper(**kwargs):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Helper function for registering ProMP gym environments.
 | 
				
			||||||
 | 
					    This can also be used standalone for manually building a custom ProMP environment.
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					        **kwargs: expects at least the following:
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					        "name": base_env_name,
 | 
				
			||||||
 | 
					        "wrappers": list of wrappers (at least an MPEnvWrapper),
 | 
				
			||||||
 | 
					        "mp_kwargs": dict of at least {num_dof: int, num_basis: int, width: int}
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns: ProMP wrapped gym env
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    seed = kwargs.pop("seed", None)
 | 
				
			||||||
 | 
					    return make_promp_env(env_id=kwargs.pop("name"), wrappers=kwargs.pop("wrappers"), seed=seed,
 | 
				
			||||||
 | 
					                          mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def make_detpmp_env_helper(**kwargs):
 | 
					def make_detpmp_env_helper(**kwargs):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Helper function for registering ProMP gym environments.
 | 
					    Helper function for registering ProMP gym environments.
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user