replaced all detpmp with promp

This commit is contained in:
Maximilian Huettenrauch 2021-11-30 16:11:32 +01:00
parent 655d52aa35
commit 8b88ce3476
20 changed files with 121 additions and 407 deletions

View File

@ -6,14 +6,14 @@ Besides, some custom environments we also provide support for the benchmark suit
[DeepMind Control](https://deepmind.com/research/publications/2020/dm-control-Software-and-Tasks-for-Continuous-Control)
(DMC), and [Metaworld](https://meta-world.github.io/). Custom (Mujoco) gym environment can be created according
to [this guide](https://github.com/openai/gym/blob/master/docs/creating-environments.md). Unlike existing libraries, we
further support to control agents with Dynamic Movement Primitives (DMPs) and Probabilistic Movement Primitives (DetPMP,
further support to control agents with Dynamic Movement Primitives (DMPs) and Probabilistic Movement Primitives (ProMP,
we only consider the mean usually).
## Motion Primitive Environments (Episodic environments)
Unlike step-based environments, motion primitive (MP) environments are closer related to stochastic search, black box
optimization and methods that often used in robotics. MP environments are trajectory-based and always execute a full
trajectory, which is generated by a Dynamic Motion Primitive (DMP) or a Probabilistic Motion Primitive (DetPMP). The
trajectory, which is generated by a Dynamic Motion Primitive (DMP) or a Probabilistic Motion Primitive (ProMP). The
generated trajectory is translated into individual step-wise actions by a controller. The exact choice of controller is,
however, dependent on the type of environment. We currently support position, velocity, and PD-Controllers for position,
velocity and torque control, respectively. The goal of all MP environments is still to learn a policy. Yet, an action
@ -82,7 +82,7 @@ trajectory.
```python
import alr_envs
env = alr_envs.make('HoleReacherDetPMP-v0', seed=1)
env = alr_envs.make('HoleReacherProMP-v0', seed=1)
# render() can be called once in the beginning with all necessary arguments. To turn it of again just call render(None).
env.render()
@ -96,7 +96,7 @@ for i in range(5):
```
To show all available environments, we provide some additional convenience. Each value will return a dictionary with two
keys `DMP` and `DetPMP` that store a list of available environment names.
keys `DMP` and `ProMP` that store a list of available environment names.
```python
import alr_envs
@ -194,7 +194,7 @@ mp_kwargs = {...}
kwargs = {...}
env = alr_envs.make_dmp_env(base_env_id, wrappers=wrappers, seed=1, mp_kwargs=mp_kwargs, **kwargs)
# OR for a deterministic ProMP (other mp_kwargs are required):
# env = alr_envs.make_detpmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args)
# env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args)
rewards = 0
obs = env.reset()

View File

@ -1,5 +1,5 @@
from alr_envs import dmc, meta, open_ai
from alr_envs.utils.make_env_helpers import make, make_detpmp_env, make_dmp_env, make_promp_env, make_rank
from alr_envs.utils.make_env_helpers import make, make_dmp_env, make_promp_env, make_rank
from alr_envs.utils import make_dmc
# Convenience function for all MP environments

View File

@ -9,7 +9,7 @@ from .mujoco.ball_in_a_cup.biac_pd import ALRBallInACupPDEnv
from .mujoco.reacher.alr_reacher import ALRReacherEnv
from .mujoco.reacher.balancing import BalancingEnv
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []}
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
# Classic Control
## Simple Reacher
@ -213,8 +213,12 @@ for _v in _versions:
"duration": 2,
"alpha_phase": 2,
"learn_goal": True,
"policy_type": "velocity",
"policy_type": "motor",
"weights_scale": 50,
"policy_kwargs": {
"p_gains": .6,
"d_gains": .075
}
}
}
)
@ -233,33 +237,16 @@ for _v in _versions:
"duration": 2,
"policy_type": "motor",
"weights_scale": 1,
"zero_start": True
"zero_start": True,
"policy_kwargs": {
"p_gains": .6,
"d_gains": .075
}
}
}
)
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
_env_id = f'{_name[0]}DetPMP-{_name[1]}'
register(
id=_env_id,
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
# max_episode_steps=1,
kwargs={
"name": f"alr_envs:{_v}",
"wrappers": [classic_control.simple_reacher.MPWrapper],
"mp_kwargs": {
"num_dof": 2 if "long" not in _v.lower() else 5,
"num_basis": 5,
"duration": 2,
"width": 0.025,
"policy_type": "velocity",
"weights_scale": 0.2,
"zero_start": True
}
}
)
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id)
# Viapoint reacher
register(
id='ViaPointReacherDMP-v0',
@ -291,7 +278,7 @@ register(
"num_dof": 5,
"num_basis": 5,
"duration": 2,
"policy_type": "motor",
"policy_type": "velocity",
"weights_scale": 1,
"zero_start": True
}
@ -299,26 +286,6 @@ register(
)
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("ViaPointReacherProMP-v0")
register(
id='ViaPointReacherDetPMP-v0',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
# max_episode_steps=1,
kwargs={
"name": "alr_envs:ViaPointReacher-v0",
"wrappers": [classic_control.viapoint_reacher.MPWrapper],
"mp_kwargs": {
"num_dof": 5,
"num_basis": 5,
"duration": 2,
"width": 0.025,
"policy_type": "velocity",
"weights_scale": 0.2,
"zero_start": True
}
}
)
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("ViaPointReacherDetPMP-v0")
## Hole Reacher
_versions = ["v0", "v1", "v2"]
for _v in _versions:
@ -363,23 +330,3 @@ for _v in _versions:
}
)
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
_env_id = f'HoleReacherDetPMP-{_v}'
register(
id=_env_id,
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
kwargs={
"name": f"alr_envs:HoleReacher-{_v}",
"wrappers": [classic_control.hole_reacher.MPWrapper],
"mp_kwargs": {
"num_dof": 5,
"num_basis": 5,
"duration": 2,
"width": 0.025,
"policy_type": "velocity",
"weights_scale": 0.2,
"zero_start": True
}
}
)
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id)

View File

@ -18,4 +18,4 @@
|`ALRBallInACupDMP-v0`| A DMP provides a trajectory for the `ALRBallInACup-v0` task. | 4000 | 35
|`ALRBallInACupGoalDMP-v0`| A DMP provides a trajectory for the `ALRBallInACupGoal-v0` task. | 4000 | 35 | 3
[//]: |`HoleReacherDetPMP-v0`|
[//]: |`HoleReacherProMP-v0`|

View File

@ -5,7 +5,6 @@ import matplotlib.pyplot as plt
import numpy as np
from gym.utils import seeding
from alr_envs.alr.classic_control.utils import check_self_collision
from alr_envs.alr.classic_control.base_reacher.base_reacher_direct import BaseReacherDirectEnv

View File

@ -1,107 +0,0 @@
from alr_envs.alr.mujoco.ball_in_a_cup.ball_in_a_cup import ALRBallInACupEnv
from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper
from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper
def make_contextual_env(rank, seed=0):
"""
Utility function for multiprocessed env.
:param env_id: (str) the environment ID
:param num_env: (int) the number of environments you wish to have in subprocesses
:param seed: (int) the initial seed for RNG
:param rank: (int) index of the subprocess
:returns a function that generates an environment
"""
def _init():
env = ALRBallInACupEnv(reward_type="contextual_goal")
env = DetPMPWrapper(env, num_dof=7, num_basis=5, width=0.005, duration=3.5, dt=env.dt, post_traj_time=4.5,
policy_type="motor", weights_scale=0.5, zero_start=True, zero_goal=True)
env.seed(seed + rank)
return env
return _init
def _make_env(rank, seed=0):
"""
Utility function for multiprocessed env.
:param env_id: (str) the environment ID
:param num_env: (int) the number of environments you wish to have in subprocesses
:param seed: (int) the initial seed for RNG
:param rank: (int) index of the subprocess
:returns a function that generates an environment
"""
def _init():
env = ALRBallInACupEnv(reward_type="simple")
env = DetPMPWrapper(env, num_dof=7, num_basis=5, width=0.005, duration=3.5, dt=env.dt, post_traj_time=4.5,
policy_type="motor", weights_scale=0.2, zero_start=True, zero_goal=True)
env.seed(seed + rank)
return env
return _init
def make_simple_env(rank, seed=0):
"""
Utility function for multiprocessed env.
:param env_id: (str) the environment ID
:param num_env: (int) the number of environments you wish to have in subprocesses
:param seed: (int) the initial seed for RNG
:param rank: (int) index of the subprocess
:returns a function that generates an environment
"""
def _init():
env = ALRBallInACupEnv(reward_type="simple")
env = DetPMPWrapper(env, num_dof=3, num_basis=5, width=0.005, duration=3.5, dt=env.dt, post_traj_time=4.5,
policy_type="motor", weights_scale=0.25, zero_start=True, zero_goal=True, off=-0.1)
env.seed(seed + rank)
return env
return _init
def make_simple_dmp_env(rank, seed=0):
"""
Utility function for multiprocessed env.
:param env_id: (str) the environment ID
:param num_env: (int) the number of environments you wish to have in subprocesses
:param seed: (int) the initial seed for RNG
:param rank: (int) index of the subprocess
:returns a function that generates an environment
"""
def _init():
_env = ALRBallInACupEnv(reward_type="simple")
_env = DmpWrapper(_env,
num_dof=3,
num_basis=5,
duration=3.5,
post_traj_time=4.5,
bandwidth_factor=2.5,
dt=_env.dt,
learn_goal=False,
alpha_phase=3,
start_pos=_env.start_pos[1::2],
final_pos=_env.start_pos[1::2],
policy_type="motor",
weights_scale=100,
)
_env.seed(seed + rank)
return _env
return _init

View File

@ -1,72 +0,0 @@
from alr_envs.utils.mps.detpmp_wrapper import DetPMPWrapper
from alr_envs.alr.mujoco.beerpong.beerpong import ALRBeerpongEnv
from alr_envs.alr.mujoco.beerpong.beerpong_simple import ALRBeerpongEnv as ALRBeerpongEnvSimple
def make_contextual_env(rank, seed=0):
"""
Utility function for multiprocessed env.
:param env_id: (str) the environment ID
:param num_env: (int) the number of environments you wish to have in subprocesses
:param seed: (int) the initial seed for RNG
:param rank: (int) index of the subprocess
:returns a function that generates an environment
"""
def _init():
env = ALRBeerpongEnv()
env = DetPMPWrapper(env, num_dof=7, num_basis=5, width=0.005, duration=3.5, dt=env.dt, post_traj_time=4.5,
policy_type="motor", weights_scale=0.5, zero_start=True, zero_goal=True)
env.seed(seed + rank)
return env
return _init
def _make_env(rank, seed=0):
"""
Utility function for multiprocessed env.
:param env_id: (str) the environment ID
:param num_env: (int) the number of environments you wish to have in subprocesses
:param seed: (int) the initial seed for RNG
:param rank: (int) index of the subprocess
:returns a function that generates an environment
"""
def _init():
env = ALRBeerpongEnvSimple()
env = DetPMPWrapper(env, num_dof=7, num_basis=5, width=0.005, duration=3.5, dt=env.dt, post_traj_time=4.5,
policy_type="motor", weights_scale=0.25, zero_start=True, zero_goal=True)
env.seed(seed + rank)
return env
return _init
def make_simple_env(rank, seed=0):
"""
Utility function for multiprocessed env.
:param env_id: (str) the environment ID
:param num_env: (int) the number of environments you wish to have in subprocesses
:param seed: (int) the initial seed for RNG
:param rank: (int) index of the subprocess
:returns a function that generates an environment
"""
def _init():
env = ALRBeerpongEnvSimple()
env = DetPMPWrapper(env, num_dof=3, num_basis=5, width=0.005, duration=3.5, dt=env.dt, post_traj_time=4.5,
policy_type="motor", weights_scale=0.5, zero_start=True, zero_goal=True)
env.seed(seed + rank)
return env
return _init

View File

@ -11,9 +11,9 @@ environments in order to use our Motion Primitive gym interface with them.
|Name| Description|Trajectory Horizon|Action Dimension|Context Dimension
|---|---|---|---|---|
|`dmc_ball_in_cup-catch_detpmp-v0`| A DetPmP wrapped version of the "catch" task for the "ball_in_cup" environment. | 1000 | 10 | 2
|`dmc_ball_in_cup-catch_promp-v0`| A ProMP wrapped version of the "catch" task for the "ball_in_cup" environment. | 1000 | 10 | 2
|`dmc_ball_in_cup-catch_dmp-v0`| A DMP wrapped version of the "catch" task for the "ball_in_cup" environment. | 1000| 10 | 2
|`dmc_reacher-easy_detpmp-v0`| A DetPmP wrapped version of the "easy" task for the "reacher" environment. | 1000 | 10 | 4
|`dmc_reacher-easy_promp-v0`| A ProMP wrapped version of the "easy" task for the "reacher" environment. | 1000 | 10 | 4
|`dmc_reacher-easy_dmp-v0`| A DMP wrapped version of the "easy" task for the "reacher" environment. | 1000| 10 | 4
|`dmc_reacher-hard_detpmp-v0`| A DetPmP wrapped version of the "hard" task for the "reacher" environment.| 1000 | 10 | 4
|`dmc_reacher-hard_promp-v0`| A ProMP wrapped version of the "hard" task for the "reacher" environment.| 1000 | 10 | 4
|`dmc_reacher-hard_dmp-v0`| A DMP wrapped version of the "hard" task for the "reacher" environment. | 1000 | 10 | 4

View File

@ -1,6 +1,6 @@
from . import manipulation, suite
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []}
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
from gym.envs.registration import register
@ -34,8 +34,8 @@ register(
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_ball_in_cup-catch_dmp-v0")
register(
id=f'dmc_ball_in_cup-catch_detpmp-v0',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
id=f'dmc_ball_in_cup-catch_promp-v0',
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={
"name": f"ball_in_cup-catch",
"time_limit": 20,
@ -45,7 +45,6 @@ register(
"num_dof": 2,
"num_basis": 5,
"duration": 20,
"width": 0.025,
"policy_type": "motor",
"zero_start": True,
"policy_kwargs": {
@ -55,7 +54,7 @@ register(
}
}
)
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("dmc_ball_in_cup-catch_detpmp-v0")
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_ball_in_cup-catch_promp-v0")
register(
id=f'dmc_reacher-easy_dmp-v0',
@ -86,8 +85,8 @@ register(
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-easy_dmp-v0")
register(
id=f'dmc_reacher-easy_detpmp-v0',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
id=f'dmc_reacher-easy_promp-v0',
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={
"name": f"reacher-easy",
"time_limit": 20,
@ -97,7 +96,6 @@ register(
"num_dof": 2,
"num_basis": 5,
"duration": 20,
"width": 0.025,
"policy_type": "motor",
"weights_scale": 0.2,
"zero_start": True,
@ -108,7 +106,7 @@ register(
}
}
)
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("dmc_reacher-easy_detpmp-v0")
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_reacher-easy_promp-v0")
register(
id=f'dmc_reacher-hard_dmp-v0',
@ -139,8 +137,8 @@ register(
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-hard_dmp-v0")
register(
id=f'dmc_reacher-hard_detpmp-v0',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
id=f'dmc_reacher-hard_promp-v0',
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={
"name": f"reacher-hard",
"time_limit": 20,
@ -150,7 +148,6 @@ register(
"num_dof": 2,
"num_basis": 5,
"duration": 20,
"width": 0.025,
"policy_type": "motor",
"weights_scale": 0.2,
"zero_start": True,
@ -161,7 +158,7 @@ register(
}
}
)
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("dmc_reacher-hard_detpmp-v0")
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_reacher-hard_promp-v0")
_dmc_cartpole_tasks = ["balance", "balance_sparse", "swingup", "swingup_sparse"]
@ -196,10 +193,10 @@ for _task in _dmc_cartpole_tasks:
)
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
_env_id = f'dmc_cartpole-{_task}_detpmp-v0'
_env_id = f'dmc_cartpole-{_task}_promp-v0'
register(
id=_env_id,
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={
"name": f"cartpole-{_task}",
# "time_limit": 1,
@ -210,7 +207,6 @@ for _task in _dmc_cartpole_tasks:
"num_dof": 1,
"num_basis": 5,
"duration": 10,
"width": 0.025,
"policy_type": "motor",
"weights_scale": 0.2,
"zero_start": True,
@ -221,7 +217,7 @@ for _task in _dmc_cartpole_tasks:
}
}
)
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id)
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
_env_id = f'dmc_cartpole-two_poles_dmp-v0'
register(
@ -253,10 +249,10 @@ register(
)
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
_env_id = f'dmc_cartpole-two_poles_detpmp-v0'
_env_id = f'dmc_cartpole-two_poles_promp-v0'
register(
id=_env_id,
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={
"name": f"cartpole-two_poles",
# "time_limit": 1,
@ -267,7 +263,6 @@ register(
"num_dof": 1,
"num_basis": 5,
"duration": 10,
"width": 0.025,
"policy_type": "motor",
"weights_scale": 0.2,
"zero_start": True,
@ -278,7 +273,7 @@ register(
}
}
)
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id)
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
_env_id = f'dmc_cartpole-three_poles_dmp-v0'
register(
@ -310,10 +305,10 @@ register(
)
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
_env_id = f'dmc_cartpole-three_poles_detpmp-v0'
_env_id = f'dmc_cartpole-three_poles_promp-v0'
register(
id=_env_id,
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={
"name": f"cartpole-three_poles",
# "time_limit": 1,
@ -324,7 +319,6 @@ register(
"num_dof": 1,
"num_basis": 5,
"duration": 10,
"width": 0.025,
"policy_type": "motor",
"weights_scale": 0.2,
"zero_start": True,
@ -335,7 +329,7 @@ register(
}
}
)
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id)
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
# DeepMind Manipulation
@ -364,8 +358,8 @@ register(
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_manipulation-reach_site_dmp-v0")
register(
id=f'dmc_manipulation-reach_site_detpmp-v0',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
id=f'dmc_manipulation-reach_site_promp-v0',
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={
"name": f"manipulation-reach_site_features",
# "time_limit": 1,
@ -375,11 +369,10 @@ register(
"num_dof": 9,
"num_basis": 5,
"duration": 10,
"width": 0.025,
"policy_type": "velocity",
"weights_scale": 0.2,
"zero_start": True,
}
}
)
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("dmc_manipulation-reach_site_detpmp-v0")
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_manipulation-reach_site_promp-v0")

View File

@ -84,7 +84,7 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
}
env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs)
# OR for a deterministic ProMP (other mp_kwargs are required, see metaworld_examples):
# env = alr_envs.make_detpmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args)
# env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args)
# This renders the full MP trajectory
# It is only required to call render() once in the beginning, which renders every consecutive trajectory.
@ -128,7 +128,7 @@ if __name__ == '__main__':
example_dmc("manipulation-reach_site_features", seed=10, iterations=250, render=render)
# Gym + DMC hybrid task provided in the MP framework
example_dmc("dmc_ball_in_cup-catch_detpmp-v0", seed=10, iterations=1, render=render)
example_dmc("dmc_ball_in_cup-catch_promp-v0", seed=10, iterations=1, render=render)
# Custom DMC task
# Different seed, because the episode is longer for this example and the name+seed combo is already registered above

View File

@ -76,7 +76,7 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
"policy_type": "metaworld", # custom controller type for metaworld environments
}
env = alr_envs.make_detpmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
# OR for a DMP (other mp_kwargs are required, see dmc_examples):
# env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs)
@ -122,7 +122,7 @@ if __name__ == '__main__':
example_dmc("button-press-v2", seed=10, iterations=500, render=render)
# MP + MetaWorld hybrid task provided in the our framework
example_dmc("ButtonPressDetPMP-v2", seed=10, iterations=1, render=render)
example_dmc("ButtonPressProMP-v2", seed=10, iterations=1, render=render)
# Custom MetaWorld task
example_custom_dmc_and_mp(seed=10, iterations=1, render=render)

View File

@ -126,7 +126,7 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
}
env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
# OR for a deterministic ProMP:
# env = make_detpmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
# env = make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
if render:
env.render(mode="human")
@ -147,7 +147,7 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
if __name__ == '__main__':
render = True
render = False
# DMP
example_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=render)

View File

@ -6,7 +6,7 @@ def example_mp(env_name, seed=1):
Example for running a motion primitive based version of a OpenAI-gym environment, which is already registered.
For more information on motion primitive specific stuff, look at the mp examples.
Args:
env_name: DetPMP env_id
env_name: ProMP env_id
seed: seed
Returns:
@ -35,7 +35,7 @@ if __name__ == '__main__':
# example_mp("ReacherDMP-v2")
# ProMP
example_mp("ContinuousMountainCarDetPMP-v0")
example_mp("ReacherDetPMP-v2")
example_mp("FetchReachDenseDetPMP-v1")
example_mp("FetchSlideDenseDetPMP-v1")
example_mp("ContinuousMountainCarProMP-v0")
example_mp("ReacherProMP-v2")
example_mp("FetchReachDenseProMP-v1")
example_mp("FetchSlideDenseProMP-v1")

View File

@ -2,7 +2,7 @@ import numpy as np
from matplotlib import pyplot as plt
from alr_envs import dmc, meta
from alr_envs.utils.make_env_helpers import make_detpmp_env
from alr_envs.utils.make_env_helpers import make_promp_env
# This might work for some environments, however, please verify either way the correct trajectory information
# for your environment are extracted below
@ -26,8 +26,8 @@ mp_kwargs = {
kwargs = dict(time_limit=2, episode_length=100)
env = make_detpmp_env(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs,
**kwargs)
env = make_promp_env(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs,
**kwargs)
# Plot difference between real trajectory and target MP trajectory
env.reset()

View File

@ -3,7 +3,7 @@ from gym import register
from . import goal_object_change_mp_wrapper, goal_change_mp_wrapper, goal_endeffector_change_mp_wrapper, \
object_change_mp_wrapper
ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []}
ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
# MetaWorld
@ -12,10 +12,10 @@ _goal_change_envs = ["assembly-v2", "pick-out-of-hole-v2", "plate-slide-v2", "pl
for _task in _goal_change_envs:
task_id_split = _task.split("-")
name = "".join([s.capitalize() for s in task_id_split[:-1]])
_env_id = f'{name}DetPMP-{task_id_split[-1]}'
_env_id = f'{name}ProMP-{task_id_split[-1]}'
register(
id=_env_id,
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={
"name": _task,
"wrappers": [goal_change_mp_wrapper.MPWrapper],
@ -24,22 +24,21 @@ for _task in _goal_change_envs:
"num_basis": 5,
"duration": 6.25,
"post_traj_time": 0,
"width": 0.025,
"zero_start": True,
"policy_type": "metaworld",
}
}
)
ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id)
ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
_object_change_envs = ["bin-picking-v2", "hammer-v2", "sweep-into-v2"]
for _task in _object_change_envs:
task_id_split = _task.split("-")
name = "".join([s.capitalize() for s in task_id_split[:-1]])
_env_id = f'{name}DetPMP-{task_id_split[-1]}'
_env_id = f'{name}ProMP-{task_id_split[-1]}'
register(
id=_env_id,
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={
"name": _task,
"wrappers": [object_change_mp_wrapper.MPWrapper],
@ -48,13 +47,12 @@ for _task in _object_change_envs:
"num_basis": 5,
"duration": 6.25,
"post_traj_time": 0,
"width": 0.025,
"zero_start": True,
"policy_type": "metaworld",
}
}
)
ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id)
ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
_goal_and_object_change_envs = ["box-close-v2", "button-press-v2", "button-press-wall-v2", "button-press-topdown-v2",
"button-press-topdown-wall-v2", "coffee-button-v2", "coffee-pull-v2",
@ -70,10 +68,10 @@ _goal_and_object_change_envs = ["box-close-v2", "button-press-v2", "button-press
for _task in _goal_and_object_change_envs:
task_id_split = _task.split("-")
name = "".join([s.capitalize() for s in task_id_split[:-1]])
_env_id = f'{name}DetPMP-{task_id_split[-1]}'
_env_id = f'{name}ProMP-{task_id_split[-1]}'
register(
id=_env_id,
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={
"name": _task,
"wrappers": [goal_object_change_mp_wrapper.MPWrapper],
@ -82,22 +80,21 @@ for _task in _goal_and_object_change_envs:
"num_basis": 5,
"duration": 6.25,
"post_traj_time": 0,
"width": 0.025,
"zero_start": True,
"policy_type": "metaworld",
}
}
)
ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id)
ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
_goal_and_endeffector_change_envs = ["basketball-v2"]
for _task in _goal_and_endeffector_change_envs:
task_id_split = _task.split("-")
name = "".join([s.capitalize() for s in task_id_split[:-1]])
_env_id = f'{name}DetPMP-{task_id_split[-1]}'
_env_id = f'{name}ProMP-{task_id_split[-1]}'
register(
id=_env_id,
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={
"name": _task,
"wrappers": [goal_endeffector_change_mp_wrapper.MPWrapper],
@ -106,10 +103,9 @@ for _task in _goal_and_endeffector_change_envs:
"num_basis": 5,
"duration": 6.25,
"post_traj_time": 0,
"width": 0.025,
"zero_start": True,
"policy_type": "metaworld",
}
}
)
ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id)
ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)

View File

@ -8,7 +8,7 @@ These environments are wrapped-versions of their OpenAI-gym counterparts.
|Name| Description|Trajectory Horizon|Action Dimension|Context Dimension
|---|---|---|---|---|
|`ContinuousMountainCarDetPMP-v0`| A DetPmP wrapped version of the ContinuousMountainCar-v0 environment. | 100 | 1
|`ReacherDetPMP-v2`| A DetPmP wrapped version of the Reacher-v2 environment. | 50 | 2
|`FetchSlideDenseDetPMP-v1`| A DetPmP wrapped version of the FetchSlideDense-v1 environment. | 50 | 4
|`FetchReachDenseDetPMP-v1`| A DetPmP wrapped version of the FetchReachDense-v1 environment. | 50 | 4
|`ContinuousMountainCarProMP-v0`| A ProMP wrapped version of the ContinuousMountainCar-v0 environment. | 100 | 1
|`ReacherProMP-v2`| A ProMP wrapped version of the Reacher-v2 environment. | 50 | 2
|`FetchSlideDenseProMP-v1`| A ProMP wrapped version of the FetchSlideDense-v1 environment. | 50 | 4
|`FetchReachDenseProMP-v1`| A ProMP wrapped version of the FetchReachDense-v1 environment. | 50 | 4

View File

@ -3,7 +3,7 @@ from gym.wrappers import FlattenObservation
from . import classic_control, mujoco, robotics
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []}
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
# Short Continuous Mountain Car
register(
@ -16,8 +16,8 @@ register(
# Open AI
# Classic Control
register(
id='ContinuousMountainCarDetPMP-v1',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
id='ContinuousMountainCarProMP-v1',
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={
"name": "alr_envs:MountainCarContinuous-v1",
"wrappers": [classic_control.continuous_mountain_car.MPWrapper],
@ -26,7 +26,6 @@ register(
"num_basis": 4,
"duration": 2,
"post_traj_time": 0,
"width": 0.02,
"zero_start": True,
"policy_type": "motor",
"policy_kwargs": {
@ -36,11 +35,11 @@ register(
}
}
)
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("ContinuousMountainCarDetPMP-v1")
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("ContinuousMountainCarProMP-v1")
register(
id='ContinuousMountainCarDetPMP-v0',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
id='ContinuousMountainCarProMP-v0',
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={
"name": "gym.envs.classic_control:MountainCarContinuous-v0",
"wrappers": [classic_control.continuous_mountain_car.MPWrapper],
@ -49,7 +48,6 @@ register(
"num_basis": 4,
"duration": 19.98,
"post_traj_time": 0,
"width": 0.02,
"zero_start": True,
"policy_type": "motor",
"policy_kwargs": {
@ -59,11 +57,11 @@ register(
}
}
)
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("ContinuousMountainCarDetPMP-v0")
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("ContinuousMountainCarProMP-v0")
register(
id='ReacherDetPMP-v2',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
id='ReacherProMP-v2',
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={
"name": "gym.envs.mujoco:Reacher-v2",
"wrappers": [mujoco.reacher_v2.MPWrapper],
@ -72,7 +70,6 @@ register(
"num_basis": 6,
"duration": 1,
"post_traj_time": 0,
"width": 0.02,
"zero_start": True,
"policy_type": "motor",
"policy_kwargs": {
@ -82,11 +79,11 @@ register(
}
}
)
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("ReacherDetPMP-v2")
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("ReacherProMP-v2")
register(
id='FetchSlideDenseDetPMP-v1',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
id='FetchSlideDenseProMP-v1',
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={
"name": "gym.envs.robotics:FetchSlideDense-v1",
"wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
@ -95,17 +92,16 @@ register(
"num_basis": 5,
"duration": 2,
"post_traj_time": 0,
"width": 0.02,
"zero_start": True,
"policy_type": "position"
}
}
)
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("FetchSlideDenseDetPMP-v1")
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("FetchSlideDenseProMP-v1")
register(
id='FetchSlideDetPMP-v1',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
id='FetchSlideProMP-v1',
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={
"name": "gym.envs.robotics:FetchSlide-v1",
"wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
@ -114,17 +110,16 @@ register(
"num_basis": 5,
"duration": 2,
"post_traj_time": 0,
"width": 0.02,
"zero_start": True,
"policy_type": "position"
}
}
)
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("FetchSlideDetPMP-v1")
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("FetchSlideProMP-v1")
register(
id='FetchReachDenseDetPMP-v1',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
id='FetchReachDenseProMP-v1',
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={
"name": "gym.envs.robotics:FetchReachDense-v1",
"wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
@ -133,17 +128,16 @@ register(
"num_basis": 5,
"duration": 2,
"post_traj_time": 0,
"width": 0.02,
"zero_start": True,
"policy_type": "position"
}
}
)
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("FetchReachDenseDetPMP-v1")
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("FetchReachDenseProMP-v1")
register(
id='FetchReachDetPMP-v1',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
id='FetchReachProMP-v1',
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={
"name": "gym.envs.robotics:FetchReach-v1",
"wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
@ -152,10 +146,9 @@ register(
"num_basis": 5,
"duration": 2,
"post_traj_time": 0,
"width": 0.02,
"zero_start": True,
"policy_type": "position"
}
}
)
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("FetchReachDetPMP-v1")
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("FetchReachProMP-v1")

View File

@ -1,3 +1,4 @@
import warnings
from typing import Iterable, Type, Union
import gym
@ -5,7 +6,6 @@ import numpy as np
from gym.envs.registration import EnvSpec
from mp_env_api import MPEnvWrapper
from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper
from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper
from mp_env_api.mp_wrappers.promp_wrapper import ProMPWrapper
@ -48,6 +48,11 @@ def make(env_id: str, seed, **kwargs):
Returns: Gym environment
"""
if any([det_pmp in env_id for det_pmp in ["DetPMP", "detpmp"]]):
warnings.warn("DetPMP is deprecated and converted to ProMP")
env_id = env_id.replace("DetPMP", "ProMP")
env_id = env_id.replace("detpmp", "promp")
try:
# Add seed to kwargs in case it is a predefined gym+dmc hybrid environment.
if env_id.startswith("dmc"):
@ -153,26 +158,6 @@ def make_promp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwar
return ProMPWrapper(_env, **mp_kwargs)
def make_detpmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs):
"""
This can also be used standalone for manually building a custom Det ProMP environment.
Args:
env_id: base_env_name,
wrappers: list of wrappers (at least an MPEnvWrapper),
mp_kwargs: dict of at least {num_dof: int, num_basis: int, width: int}
Returns: Det ProMP wrapped gym env
"""
_verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
_env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
_verify_dof(_env, mp_kwargs.get("num_dof"))
return DetPMPWrapper(_env, **mp_kwargs)
def make_dmp_env_helper(**kwargs):
"""
Helper function for registering a DMP gym environments.
@ -212,26 +197,6 @@ def make_promp_env_helper(**kwargs):
mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs)
def make_detpmp_env_helper(**kwargs):
"""
Helper function for registering ProMP gym environments.
This can also be used standalone for manually building a custom ProMP environment.
Args:
**kwargs: expects at least the following:
{
"name": base_env_name,
"wrappers": list of wrappers (at least an MPEnvWrapper),
"mp_kwargs": dict of at least {num_dof: int, num_basis: int, width: int}
}
Returns: DMP wrapped gym env
"""
seed = kwargs.pop("seed", None)
return make_detpmp_env(env_id=kwargs.pop("name"), wrappers=kwargs.pop("wrappers"), seed=seed,
mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs)
def _verify_time_limit(mp_time_limit: Union[None, float], env_time_limit: Union[None, float]):
"""
When using DMC check if a manually specified time limit matches the trajectory duration the MP receives.

View File

@ -98,8 +98,8 @@ class TestMPEnvironments(unittest.TestCase):
with self.subTest(msg=env_id):
self._run_env(env_id)
with self.subTest(msg="DetPMP"):
for env_id in alr_envs.ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS['DetPMP']:
with self.subTest(msg="ProMP"):
for env_id in alr_envs.ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS['ProMP']:
with self.subTest(msg=env_id):
self._run_env(env_id)
@ -110,8 +110,8 @@ class TestMPEnvironments(unittest.TestCase):
with self.subTest(msg=env_id):
self._run_env(env_id)
with self.subTest(msg="DetPMP"):
for env_id in alr_envs.ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS['DetPMP']:
with self.subTest(msg="ProMP"):
for env_id in alr_envs.ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS['ProMP']:
with self.subTest(msg=env_id):
self._run_env(env_id)
@ -122,8 +122,8 @@ class TestMPEnvironments(unittest.TestCase):
with self.subTest(msg=env_id):
self._run_env(env_id)
with self.subTest(msg="DetPMP"):
for env_id in alr_envs.ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS['DetPMP']:
with self.subTest(msg="ProMP"):
for env_id in alr_envs.ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS['ProMP']:
with self.subTest(msg=env_id):
self._run_env(env_id)
@ -134,8 +134,8 @@ class TestMPEnvironments(unittest.TestCase):
with self.subTest(msg=env_id):
self._run_env(env_id)
with self.subTest(msg="DetPMP"):
for env_id in alr_envs.ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS['DetPMP']:
with self.subTest(msg="ProMP"):
for env_id in alr_envs.ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS['ProMP']:
with self.subTest(msg=env_id):
self._run_env(env_id)
@ -143,29 +143,29 @@ class TestMPEnvironments(unittest.TestCase):
"""Tests that identical seeds produce identical trajectories for ALR MP Envs."""
with self.subTest(msg="DMP"):
self._run_env_determinism(alr_envs.ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"])
with self.subTest(msg="DetPMP"):
self._run_env_determinism(alr_envs.ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"])
with self.subTest(msg="ProMP"):
self._run_env_determinism(alr_envs.ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"])
def test_openai_environment_determinism(self):
"""Tests that identical seeds produce identical trajectories for OpenAI gym MP Envs."""
with self.subTest(msg="DMP"):
self._run_env_determinism(alr_envs.ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"])
with self.subTest(msg="DetPMP"):
self._run_env_determinism(alr_envs.ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"])
with self.subTest(msg="ProMP"):
self._run_env_determinism(alr_envs.ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"])
def test_dmc_environment_determinism(self):
"""Tests that identical seeds produce identical trajectories for DMC MP Envs."""
with self.subTest(msg="DMP"):
self._run_env_determinism(alr_envs.ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"])
with self.subTest(msg="DetPMP"):
self._run_env_determinism(alr_envs.ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"])
with self.subTest(msg="ProMP"):
self._run_env_determinism(alr_envs.ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"])
def test_metaworld_environment_determinism(self):
"""Tests that identical seeds produce identical trajectories for Metaworld MP Envs."""
with self.subTest(msg="DMP"):
self._run_env_determinism(alr_envs.ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"])
with self.subTest(msg="DetPMP"):
self._run_env_determinism(alr_envs.ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"])
with self.subTest(msg="ProMP"):
self._run_env_determinism(alr_envs.ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"])
if __name__ == '__main__':

View File

@ -81,13 +81,13 @@ class TestStepMetaWorlEnvironments(unittest.TestCase):
def _verify_done(self, done):
self.assertIsInstance(done, bool, f"Returned {done} as done flag, expected bool.")
def test_dmc_functionality(self):
def test_metaworld_functionality(self):
"""Tests that environments runs without errors using random actions."""
for env_id in ALL_ENVS:
with self.subTest(msg=env_id):
self._run_env(env_id)
def test_dmc_determinism(self):
def test_metaworld_determinism(self):
"""Tests that identical seeds produce identical trajectories."""
seed = 0
# Iterate over two trajectories, which should have the same state and action sequence