replaced all detpmp with promp
parent 655d52aa35
commit 8b88ce3476
@@ -6,14 +6,14 @@ Besides, some custom environments we also provide support for the benchmark suit
 [DeepMind Control](https://deepmind.com/research/publications/2020/dm-control-Software-and-Tasks-for-Continuous-Control)
 (DMC), and [Metaworld](https://meta-world.github.io/). Custom (Mujoco) gym environment can be created according
 to [this guide](https://github.com/openai/gym/blob/master/docs/creating-environments.md). Unlike existing libraries, we
-further support to control agents with Dynamic Movement Primitives (DMPs) and Probabilistic Movement Primitives (DetPMP,
+further support to control agents with Dynamic Movement Primitives (DMPs) and Probabilistic Movement Primitives (ProMP,
 we only consider the mean usually).
 
 ## Motion Primitive Environments (Episodic environments)
 
 Unlike step-based environments, motion primitive (MP) environments are closer related to stochastic search, black box
 optimization and methods that often used in robotics. MP environments are trajectory-based and always execute a full
-trajectory, which is generated by a Dynamic Motion Primitive (DMP) or a Probabilistic Motion Primitive (DetPMP). The
+trajectory, which is generated by a Dynamic Motion Primitive (DMP) or a Probabilistic Motion Primitive (ProMP). The
 generated trajectory is translated into individual step-wise actions by a controller. The exact choice of controller is,
 however, dependent on the type of environment. We currently support position, velocity, and PD-Controllers for position,
 velocity and torque control, respectively. The goal of all MP environments is still to learn a policy. Yet, an action
@@ -82,7 +82,7 @@ trajectory.
 ```python
 import alr_envs
 
-env = alr_envs.make('HoleReacherDetPMP-v0', seed=1)
+env = alr_envs.make('HoleReacherProMP-v0', seed=1)
 # render() can be called once in the beginning with all necessary arguments. To turn it of again just call render(None).
 env.render()
 
@@ -96,7 +96,7 @@ for i in range(5):
 ```
 
 To show all available environments, we provide some additional convenience. Each value will return a dictionary with two
-keys `DMP` and `DetPMP` that store a list of available environment names.
+keys `DMP` and `ProMP` that store a list of available environment names.
 
 ```python
 import alr_envs
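As a quick cross-check of the convenience dictionaries described above, a short sketch; the dictionary is exposed at package level (the test suite changed later in this commit reads it the same way), and picking the first entry is only an illustration:

```python
import alr_envs

# All episodic ALR environments registered under the renamed "ProMP" key
# (the old "DetPMP" key is removed by this commit).
print(alr_envs.ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"])

# Any listed id can be instantiated exactly like in the README snippet above.
env = alr_envs.make(alr_envs.ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"][0], seed=1)
```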
@@ -194,7 +194,7 @@ mp_kwargs = {...}
 kwargs = {...}
 env = alr_envs.make_dmp_env(base_env_id, wrappers=wrappers, seed=1, mp_kwargs=mp_kwargs, **kwargs)
 # OR for a deterministic ProMP (other mp_kwargs are required):
-# env = alr_envs.make_detpmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args)
+# env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args)
 
 rewards = 0
 obs = env.reset()
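For completeness, a hedged sketch of calling the renamed helper directly with an explicit `mp_kwargs` dict. The values mirror the SimpleReacher ProMP registration further down in this commit; the base environment id and the wrapper import path are assumptions and may differ from the actual registry entries:

```python
import alr_envs
from alr_envs.alr import classic_control  # assumed import path, mirroring the registrations below

mp_kwargs = {
    "num_dof": 2,
    "num_basis": 5,
    "duration": 2,
    "policy_type": "motor",
    "weights_scale": 1,
    "zero_start": True,
    "policy_kwargs": {"p_gains": .6, "d_gains": .075},
}
# "alr_envs:SimpleReacher-v0" is a placeholder id; substitute the base environment you registered.
env = alr_envs.make_promp_env("alr_envs:SimpleReacher-v0",
                              wrappers=[classic_control.simple_reacher.MPWrapper],
                              seed=1, mp_kwargs=mp_kwargs)
```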
@@ -1,5 +1,5 @@
 from alr_envs import dmc, meta, open_ai
-from alr_envs.utils.make_env_helpers import make, make_detpmp_env, make_dmp_env, make_promp_env, make_rank
+from alr_envs.utils.make_env_helpers import make, make_dmp_env, make_promp_env, make_rank
 from alr_envs.utils import make_dmc
 
 # Convenience function for all MP environments
@@ -9,7 +9,7 @@ from .mujoco.ball_in_a_cup.biac_pd import ALRBallInACupPDEnv
 from .mujoco.reacher.alr_reacher import ALRReacherEnv
 from .mujoco.reacher.balancing import BalancingEnv
 
-ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []}
+ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
 
 # Classic Control
 ## Simple Reacher
@@ -213,8 +213,12 @@ for _v in _versions:
 "duration": 2,
 "alpha_phase": 2,
 "learn_goal": True,
-"policy_type": "velocity",
+"policy_type": "motor",
 "weights_scale": 50,
+"policy_kwargs": {
+"p_gains": .6,
+"d_gains": .075
+}
 }
 }
 )
@@ -233,33 +237,16 @@ for _v in _versions:
 "duration": 2,
 "policy_type": "motor",
 "weights_scale": 1,
-"zero_start": True
+"zero_start": True,
+"policy_kwargs": {
+"p_gains": .6,
+"d_gains": .075
+}
 }
 }
 )
 ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
 
-_env_id = f'{_name[0]}DetPMP-{_name[1]}'
-register(
-id=_env_id,
-entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
-# max_episode_steps=1,
-kwargs={
-"name": f"alr_envs:{_v}",
-"wrappers": [classic_control.simple_reacher.MPWrapper],
-"mp_kwargs": {
-"num_dof": 2 if "long" not in _v.lower() else 5,
-"num_basis": 5,
-"duration": 2,
-"width": 0.025,
-"policy_type": "velocity",
-"weights_scale": 0.2,
-"zero_start": True
-}
-}
-)
-ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id)
-
 # Viapoint reacher
 register(
 id='ViaPointReacherDMP-v0',
@@ -291,7 +278,7 @@ register(
 "num_dof": 5,
 "num_basis": 5,
 "duration": 2,
-"policy_type": "motor",
+"policy_type": "velocity",
 "weights_scale": 1,
 "zero_start": True
 }
@@ -299,26 +286,6 @@ register(
 )
 ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("ViaPointReacherProMP-v0")
 
-register(
-id='ViaPointReacherDetPMP-v0',
-entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
-# max_episode_steps=1,
-kwargs={
-"name": "alr_envs:ViaPointReacher-v0",
-"wrappers": [classic_control.viapoint_reacher.MPWrapper],
-"mp_kwargs": {
-"num_dof": 5,
-"num_basis": 5,
-"duration": 2,
-"width": 0.025,
-"policy_type": "velocity",
-"weights_scale": 0.2,
-"zero_start": True
-}
-}
-)
-ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("ViaPointReacherDetPMP-v0")
-
 ## Hole Reacher
 _versions = ["v0", "v1", "v2"]
 for _v in _versions:
@@ -363,23 +330,3 @@ for _v in _versions:
 }
 )
 ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
-
-_env_id = f'HoleReacherDetPMP-{_v}'
-register(
-id=_env_id,
-entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
-kwargs={
-"name": f"alr_envs:HoleReacher-{_v}",
-"wrappers": [classic_control.hole_reacher.MPWrapper],
-"mp_kwargs": {
-"num_dof": 5,
-"num_basis": 5,
-"duration": 2,
-"width": 0.025,
-"policy_type": "velocity",
-"weights_scale": 0.2,
-"zero_start": True
-}
-}
-)
-ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id)
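The loop above keeps one `HoleReacherProMP-<version>` registration per entry of `_versions`, so the ids resolve through the normal gym registry once the package is imported. A small sketch (the observation being a flat array is an assumption):

```python
import alr_envs

# Ids produced by the registration loop above; HoleReacherProMP-v0 is also the README example.
for version in ["v0", "v1", "v2"]:
    env = alr_envs.make(f"HoleReacherProMP-{version}", seed=1)
    print(env.reset().shape)  # assumes the MP wrapper returns a flat observation array
```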
@@ -18,4 +18,4 @@
 |`ALRBallInACupDMP-v0`| A DMP provides a trajectory for the `ALRBallInACup-v0` task. | 4000 | 35
 |`ALRBallInACupGoalDMP-v0`| A DMP provides a trajectory for the `ALRBallInACupGoal-v0` task. | 4000 | 35 | 3
 
-[//]: |`HoleReacherDetPMP-v0`|
+[//]: |`HoleReacherProMPP-v0`|
@@ -5,7 +5,6 @@ import matplotlib.pyplot as plt
 import numpy as np
 from gym.utils import seeding
 
 from alr_envs.alr.classic_control.utils import check_self_collision
 from alr_envs.alr.classic_control.base_reacher.base_reacher_direct import BaseReacherDirectEnv
 
 
@@ -1,107 +0,0 @@
-from alr_envs.alr.mujoco.ball_in_a_cup.ball_in_a_cup import ALRBallInACupEnv
-from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper
-from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper
-
-
-def make_contextual_env(rank, seed=0):
-"""
-Utility function for multiprocessed env.
-
-:param env_id: (str) the environment ID
-:param num_env: (int) the number of environments you wish to have in subprocesses
-:param seed: (int) the initial seed for RNG
-:param rank: (int) index of the subprocess
-:returns a function that generates an environment
-"""
-
-def _init():
-env = ALRBallInACupEnv(reward_type="contextual_goal")
-
-env = DetPMPWrapper(env, num_dof=7, num_basis=5, width=0.005, duration=3.5, dt=env.dt, post_traj_time=4.5,
-policy_type="motor", weights_scale=0.5, zero_start=True, zero_goal=True)
-
-env.seed(seed + rank)
-return env
-
-return _init
-
-
-def _make_env(rank, seed=0):
-"""
-Utility function for multiprocessed env.
-
-:param env_id: (str) the environment ID
-:param num_env: (int) the number of environments you wish to have in subprocesses
-:param seed: (int) the initial seed for RNG
-:param rank: (int) index of the subprocess
-:returns a function that generates an environment
-"""
-
-def _init():
-env = ALRBallInACupEnv(reward_type="simple")
-
-env = DetPMPWrapper(env, num_dof=7, num_basis=5, width=0.005, duration=3.5, dt=env.dt, post_traj_time=4.5,
-policy_type="motor", weights_scale=0.2, zero_start=True, zero_goal=True)
-
-env.seed(seed + rank)
-return env
-
-return _init
-
-
-def make_simple_env(rank, seed=0):
-"""
-Utility function for multiprocessed env.
-
-:param env_id: (str) the environment ID
-:param num_env: (int) the number of environments you wish to have in subprocesses
-:param seed: (int) the initial seed for RNG
-:param rank: (int) index of the subprocess
-:returns a function that generates an environment
-"""
-
-def _init():
-env = ALRBallInACupEnv(reward_type="simple")
-
-env = DetPMPWrapper(env, num_dof=3, num_basis=5, width=0.005, duration=3.5, dt=env.dt, post_traj_time=4.5,
-policy_type="motor", weights_scale=0.25, zero_start=True, zero_goal=True, off=-0.1)
-
-env.seed(seed + rank)
-return env
-
-return _init
-
-
-def make_simple_dmp_env(rank, seed=0):
-"""
-Utility function for multiprocessed env.
-
-:param env_id: (str) the environment ID
-:param num_env: (int) the number of environments you wish to have in subprocesses
-:param seed: (int) the initial seed for RNG
-:param rank: (int) index of the subprocess
-:returns a function that generates an environment
-"""
-
-def _init():
-_env = ALRBallInACupEnv(reward_type="simple")
-
-_env = DmpWrapper(_env,
-num_dof=3,
-num_basis=5,
-duration=3.5,
-post_traj_time=4.5,
-bandwidth_factor=2.5,
-dt=_env.dt,
-learn_goal=False,
-alpha_phase=3,
-start_pos=_env.start_pos[1::2],
-final_pos=_env.start_pos[1::2],
-policy_type="motor",
-weights_scale=100,
-)
-
-_env.seed(seed + rank)
-return _env
-
-return _init
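The file deleted above provided per-rank factories built directly on `DetPMPWrapper`. A hedged sketch of the same pattern on top of the ProMP wrapper that `make_env_helpers.py` now uses; the exact keyword set accepted by `ProMPWrapper` is partly an assumption (the DetPMP-only `width`, `dt` and `zero_goal` arguments are dropped here):

```python
from mp_env_api.mp_wrappers.promp_wrapper import ProMPWrapper  # import path as in make_env_helpers.py
from alr_envs.alr.mujoco.ball_in_a_cup.ball_in_a_cup import ALRBallInACupEnv


def make_contextual_env(rank, seed=0):
    """Per-rank factory in the style of the deleted helper, rebuilt on the ProMP wrapper (sketch)."""
    def _init():
        env = ALRBallInACupEnv(reward_type="contextual_goal")
        # Values mirror the deleted DetPMPWrapper call; treat the accepted kwargs as assumptions.
        env = ProMPWrapper(env, num_dof=7, num_basis=5, duration=3.5, post_traj_time=4.5,
                           policy_type="motor", weights_scale=0.5, zero_start=True)
        env.seed(seed + rank)
        return env
    return _init
```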
@@ -1,72 +0,0 @@
-from alr_envs.utils.mps.detpmp_wrapper import DetPMPWrapper
-from alr_envs.alr.mujoco.beerpong.beerpong import ALRBeerpongEnv
-from alr_envs.alr.mujoco.beerpong.beerpong_simple import ALRBeerpongEnv as ALRBeerpongEnvSimple
-
-
-def make_contextual_env(rank, seed=0):
-"""
-Utility function for multiprocessed env.
-
-:param env_id: (str) the environment ID
-:param num_env: (int) the number of environments you wish to have in subprocesses
-:param seed: (int) the initial seed for RNG
-:param rank: (int) index of the subprocess
-:returns a function that generates an environment
-"""
-
-def _init():
-env = ALRBeerpongEnv()
-
-env = DetPMPWrapper(env, num_dof=7, num_basis=5, width=0.005, duration=3.5, dt=env.dt, post_traj_time=4.5,
-policy_type="motor", weights_scale=0.5, zero_start=True, zero_goal=True)
-
-env.seed(seed + rank)
-return env
-
-return _init
-
-
-def _make_env(rank, seed=0):
-"""
-Utility function for multiprocessed env.
-
-:param env_id: (str) the environment ID
-:param num_env: (int) the number of environments you wish to have in subprocesses
-:param seed: (int) the initial seed for RNG
-:param rank: (int) index of the subprocess
-:returns a function that generates an environment
-"""
-
-def _init():
-env = ALRBeerpongEnvSimple()
-
-env = DetPMPWrapper(env, num_dof=7, num_basis=5, width=0.005, duration=3.5, dt=env.dt, post_traj_time=4.5,
-policy_type="motor", weights_scale=0.25, zero_start=True, zero_goal=True)
-
-env.seed(seed + rank)
-return env
-
-return _init
-
-
-def make_simple_env(rank, seed=0):
-"""
-Utility function for multiprocessed env.
-
-:param env_id: (str) the environment ID
-:param num_env: (int) the number of environments you wish to have in subprocesses
-:param seed: (int) the initial seed for RNG
-:param rank: (int) index of the subprocess
-:returns a function that generates an environment
-"""
-
-def _init():
-env = ALRBeerpongEnvSimple()
-
-env = DetPMPWrapper(env, num_dof=3, num_basis=5, width=0.005, duration=3.5, dt=env.dt, post_traj_time=4.5,
-policy_type="motor", weights_scale=0.5, zero_start=True, zero_goal=True)
-
-env.seed(seed + rank)
-return env
-
-return _init
@@ -11,9 +11,9 @@ environments in order to use our Motion Primitive gym interface with them.
 
 |Name| Description|Trajectory Horizon|Action Dimension|Context Dimension
 |---|---|---|---|---|
-|`dmc_ball_in_cup-catch_detpmp-v0`| A DetPmP wrapped version of the "catch" task for the "ball_in_cup" environment. | 1000 | 10 | 2
+|`dmc_ball_in_cup-catch_promp-v0`| A ProMP wrapped version of the "catch" task for the "ball_in_cup" environment. | 1000 | 10 | 2
 |`dmc_ball_in_cup-catch_dmp-v0`| A DMP wrapped version of the "catch" task for the "ball_in_cup" environment. | 1000| 10 | 2
-|`dmc_reacher-easy_detpmp-v0`| A DetPmP wrapped version of the "easy" task for the "reacher" environment. | 1000 | 10 | 4
+|`dmc_reacher-easy_promp-v0`| A ProMP wrapped version of the "easy" task for the "reacher" environment. | 1000 | 10 | 4
 |`dmc_reacher-easy_dmp-v0`| A DMP wrapped version of the "easy" task for the "reacher" environment. | 1000| 10 | 4
-|`dmc_reacher-hard_detpmp-v0`| A DetPmP wrapped version of the "hard" task for the "reacher" environment.| 1000 | 10 | 4
+|`dmc_reacher-hard_promp-v0`| A ProMP wrapped version of the "hard" task for the "reacher" environment.| 1000 | 10 | 4
 |`dmc_reacher-hard_dmp-v0`| A DMP wrapped version of the "hard" task for the "reacher" environment. | 1000 | 10 | 4
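A minimal usage sketch for one of the renamed gym+DMC hybrid ids from the table above (the same id is used in the DMC example script changed later in this commit); sampling a random parameter vector from the action space is only an illustration:

```python
import alr_envs

env = alr_envs.make("dmc_ball_in_cup-catch_promp-v0", seed=10)
obs = env.reset()
# One step() call executes the full 1000-step trajectory listed in the table above.
obs, reward, done, info = env.step(env.action_space.sample())
```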
@@ -1,6 +1,6 @@
 from . import manipulation, suite
 
-ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []}
+ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
 
 from gym.envs.registration import register
 
@@ -34,8 +34,8 @@ register(
 ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_ball_in_cup-catch_dmp-v0")
 
 register(
-id=f'dmc_ball_in_cup-catch_detpmp-v0',
-entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+id=f'dmc_ball_in_cup-catch_promp-v0',
+entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
 kwargs={
 "name": f"ball_in_cup-catch",
 "time_limit": 20,
@@ -45,7 +45,6 @@ register(
 "num_dof": 2,
 "num_basis": 5,
 "duration": 20,
-"width": 0.025,
 "policy_type": "motor",
 "zero_start": True,
 "policy_kwargs": {
@@ -55,7 +54,7 @@ register(
 }
 }
 )
-ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("dmc_ball_in_cup-catch_detpmp-v0")
+ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_ball_in_cup-catch_promp-v0")
 
 register(
 id=f'dmc_reacher-easy_dmp-v0',
@@ -86,8 +85,8 @@ register(
 ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-easy_dmp-v0")
 
 register(
-id=f'dmc_reacher-easy_detpmp-v0',
-entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+id=f'dmc_reacher-easy_promp-v0',
+entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
 kwargs={
 "name": f"reacher-easy",
 "time_limit": 20,
@@ -97,7 +96,6 @@ register(
 "num_dof": 2,
 "num_basis": 5,
 "duration": 20,
-"width": 0.025,
 "policy_type": "motor",
 "weights_scale": 0.2,
 "zero_start": True,
@@ -108,7 +106,7 @@ register(
 }
 }
 )
-ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("dmc_reacher-easy_detpmp-v0")
+ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_reacher-easy_promp-v0")
 
 register(
 id=f'dmc_reacher-hard_dmp-v0',
@@ -139,8 +137,8 @@ register(
 ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-hard_dmp-v0")
 
 register(
-id=f'dmc_reacher-hard_detpmp-v0',
-entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+id=f'dmc_reacher-hard_promp-v0',
+entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
 kwargs={
 "name": f"reacher-hard",
 "time_limit": 20,
@@ -150,7 +148,6 @@ register(
 "num_dof": 2,
 "num_basis": 5,
 "duration": 20,
-"width": 0.025,
 "policy_type": "motor",
 "weights_scale": 0.2,
 "zero_start": True,
@@ -161,7 +158,7 @@ register(
 }
 }
 )
-ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("dmc_reacher-hard_detpmp-v0")
+ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_reacher-hard_promp-v0")
 
 _dmc_cartpole_tasks = ["balance", "balance_sparse", "swingup", "swingup_sparse"]
 
@@ -196,10 +193,10 @@ for _task in _dmc_cartpole_tasks:
 )
 ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
 
-_env_id = f'dmc_cartpole-{_task}_detpmp-v0'
+_env_id = f'dmc_cartpole-{_task}_promp-v0'
 register(
 id=_env_id,
-entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
 kwargs={
 "name": f"cartpole-{_task}",
 # "time_limit": 1,
@@ -210,7 +207,6 @@ for _task in _dmc_cartpole_tasks:
 "num_dof": 1,
 "num_basis": 5,
 "duration": 10,
-"width": 0.025,
 "policy_type": "motor",
 "weights_scale": 0.2,
 "zero_start": True,
@@ -221,7 +217,7 @@ for _task in _dmc_cartpole_tasks:
 }
 }
 )
-ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id)
+ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
 
 _env_id = f'dmc_cartpole-two_poles_dmp-v0'
 register(
@@ -253,10 +249,10 @@ register(
 )
 ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
 
-_env_id = f'dmc_cartpole-two_poles_detpmp-v0'
+_env_id = f'dmc_cartpole-two_poles_promp-v0'
 register(
 id=_env_id,
-entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
 kwargs={
 "name": f"cartpole-two_poles",
 # "time_limit": 1,
@@ -267,7 +263,6 @@ register(
 "num_dof": 1,
 "num_basis": 5,
 "duration": 10,
-"width": 0.025,
 "policy_type": "motor",
 "weights_scale": 0.2,
 "zero_start": True,
@@ -278,7 +273,7 @@ register(
 }
 }
 )
-ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id)
+ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
 
 _env_id = f'dmc_cartpole-three_poles_dmp-v0'
 register(
@@ -310,10 +305,10 @@ register(
 )
 ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
 
-_env_id = f'dmc_cartpole-three_poles_detpmp-v0'
+_env_id = f'dmc_cartpole-three_poles_promp-v0'
 register(
 id=_env_id,
-entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
 kwargs={
 "name": f"cartpole-three_poles",
 # "time_limit": 1,
@@ -324,7 +319,6 @@ register(
 "num_dof": 1,
 "num_basis": 5,
 "duration": 10,
-"width": 0.025,
 "policy_type": "motor",
 "weights_scale": 0.2,
 "zero_start": True,
@@ -335,7 +329,7 @@ register(
 }
 }
 )
-ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id)
+ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
 
 # DeepMind Manipulation
 
@@ -364,8 +358,8 @@ register(
 ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_manipulation-reach_site_dmp-v0")
 
 register(
-id=f'dmc_manipulation-reach_site_detpmp-v0',
-entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+id=f'dmc_manipulation-reach_site_promp-v0',
+entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
 kwargs={
 "name": f"manipulation-reach_site_features",
 # "time_limit": 1,
@@ -375,11 +369,10 @@ register(
 "num_dof": 9,
 "num_basis": 5,
 "duration": 10,
-"width": 0.025,
 "policy_type": "velocity",
 "weights_scale": 0.2,
 "zero_start": True,
 }
 }
 )
-ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("dmc_manipulation-reach_site_detpmp-v0")
+ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_manipulation-reach_site_promp-v0")
@@ -84,7 +84,7 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
 }
 env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs)
 # OR for a deterministic ProMP (other mp_kwargs are required, see metaworld_examples):
-# env = alr_envs.make_detpmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args)
+# env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args)
 
 # This renders the full MP trajectory
 # It is only required to call render() once in the beginning, which renders every consecutive trajectory.
@@ -128,7 +128,7 @@ if __name__ == '__main__':
 example_dmc("manipulation-reach_site_features", seed=10, iterations=250, render=render)
 
 # Gym + DMC hybrid task provided in the MP framework
-example_dmc("dmc_ball_in_cup-catch_detpmp-v0", seed=10, iterations=1, render=render)
+example_dmc("dmc_ball_in_cup-catch_promp-v0", seed=10, iterations=1, render=render)
 
 # Custom DMC task
 # Different seed, because the episode is longer for this example and the name+seed combo is already registered above
@@ -76,7 +76,7 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
 "policy_type": "metaworld", # custom controller type for metaworld environments
 }
 
-env = alr_envs.make_detpmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
+env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
 # OR for a DMP (other mp_kwargs are required, see dmc_examples):
 # env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs)
 
@@ -122,7 +122,7 @@ if __name__ == '__main__':
 example_dmc("button-press-v2", seed=10, iterations=500, render=render)
 
 # MP + MetaWorld hybrid task provided in the our framework
-example_dmc("ButtonPressDetPMP-v2", seed=10, iterations=1, render=render)
+example_dmc("ButtonPressProMP-v2", seed=10, iterations=1, render=render)
 
 # Custom MetaWorld task
 example_custom_dmc_and_mp(seed=10, iterations=1, render=render)
@@ -126,7 +126,7 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
 }
 env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
 # OR for a deterministic ProMP:
-# env = make_detpmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
+# env = make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
 
 if render:
 env.render(mode="human")
@@ -147,7 +147,7 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
 
 
 if __name__ == '__main__':
-render = True
+render = False
 # DMP
 example_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=render)
 
@@ -6,7 +6,7 @@ def example_mp(env_name, seed=1):
 Example for running a motion primitive based version of a OpenAI-gym environment, which is already registered.
 For more information on motion primitive specific stuff, look at the mp examples.
 Args:
-env_name: DetPMP env_id
+env_name: ProMP env_id
 seed: seed
 
 Returns:
@@ -35,7 +35,7 @@ if __name__ == '__main__':
 # example_mp("ReacherDMP-v2")
 
 # DetProMP
-example_mp("ContinuousMountainCarDetPMP-v0")
-example_mp("ReacherDetPMP-v2")
-example_mp("FetchReachDenseDetPMP-v1")
-example_mp("FetchSlideDenseDetPMP-v1")
+example_mp("ContinuousMountainCarProMP-v0")
+example_mp("ReacherProMP-v2")
+example_mp("FetchReachDenseProMP-v1")
+example_mp("FetchSlideDenseProMP-v1")
@@ -2,7 +2,7 @@ import numpy as np
 from matplotlib import pyplot as plt
 
 from alr_envs import dmc, meta
-from alr_envs.utils.make_env_helpers import make_detpmp_env
+from alr_envs.utils.make_env_helpers import make_promp_env
 
 # This might work for some environments, however, please verify either way the correct trajectory information
 # for your environment are extracted below
@@ -26,8 +26,8 @@ mp_kwargs = {
 
 kwargs = dict(time_limit=2, episode_length=100)
 
-env = make_detpmp_env(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs,
-**kwargs)
+env = make_promp_env(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs,
+**kwargs)
 
 # Plot difference between real trajectory and target MP trajectory
 env.reset()
@@ -3,7 +3,7 @@ from gym import register
 from . import goal_object_change_mp_wrapper, goal_change_mp_wrapper, goal_endeffector_change_mp_wrapper, \
 object_change_mp_wrapper
 
-ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []}
+ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
 
 # MetaWorld
 
@@ -12,10 +12,10 @@ _goal_change_envs = ["assembly-v2", "pick-out-of-hole-v2", "plate-slide-v2", "pl
 for _task in _goal_change_envs:
 task_id_split = _task.split("-")
 name = "".join([s.capitalize() for s in task_id_split[:-1]])
-_env_id = f'{name}DetPMP-{task_id_split[-1]}'
+_env_id = f'{name}ProMP-{task_id_split[-1]}'
 register(
 id=_env_id,
-entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
 kwargs={
 "name": _task,
 "wrappers": [goal_change_mp_wrapper.MPWrapper],
@@ -24,22 +24,21 @@ for _task in _goal_change_envs:
 "num_basis": 5,
 "duration": 6.25,
 "post_traj_time": 0,
-"width": 0.025,
 "zero_start": True,
 "policy_type": "metaworld",
 }
 }
 )
-ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id)
+ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
 
 _object_change_envs = ["bin-picking-v2", "hammer-v2", "sweep-into-v2"]
 for _task in _object_change_envs:
 task_id_split = _task.split("-")
 name = "".join([s.capitalize() for s in task_id_split[:-1]])
-_env_id = f'{name}DetPMP-{task_id_split[-1]}'
+_env_id = f'{name}ProMP-{task_id_split[-1]}'
 register(
 id=_env_id,
-entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
 kwargs={
 "name": _task,
 "wrappers": [object_change_mp_wrapper.MPWrapper],
@@ -48,13 +47,12 @@ for _task in _object_change_envs:
 "num_basis": 5,
 "duration": 6.25,
 "post_traj_time": 0,
-"width": 0.025,
 "zero_start": True,
 "policy_type": "metaworld",
 }
 }
 )
-ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id)
+ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
 
 _goal_and_object_change_envs = ["box-close-v2", "button-press-v2", "button-press-wall-v2", "button-press-topdown-v2",
 "button-press-topdown-wall-v2", "coffee-button-v2", "coffee-pull-v2",
@@ -70,10 +68,10 @@ _goal_and_object_change_envs = ["box-close-v2", "button-press-v2", "button-press
 for _task in _goal_and_object_change_envs:
 task_id_split = _task.split("-")
 name = "".join([s.capitalize() for s in task_id_split[:-1]])
-_env_id = f'{name}DetPMP-{task_id_split[-1]}'
+_env_id = f'{name}ProMP-{task_id_split[-1]}'
 register(
 id=_env_id,
-entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
 kwargs={
 "name": _task,
 "wrappers": [goal_object_change_mp_wrapper.MPWrapper],
@@ -82,22 +80,21 @@ for _task in _goal_and_object_change_envs:
 "num_basis": 5,
 "duration": 6.25,
 "post_traj_time": 0,
-"width": 0.025,
 "zero_start": True,
 "policy_type": "metaworld",
 }
 }
 )
-ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id)
+ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
 
 _goal_and_endeffector_change_envs = ["basketball-v2"]
 for _task in _goal_and_endeffector_change_envs:
 task_id_split = _task.split("-")
 name = "".join([s.capitalize() for s in task_id_split[:-1]])
-_env_id = f'{name}DetPMP-{task_id_split[-1]}'
+_env_id = f'{name}ProMP-{task_id_split[-1]}'
 register(
 id=_env_id,
-entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
 kwargs={
 "name": _task,
 "wrappers": [goal_endeffector_change_mp_wrapper.MPWrapper],
@@ -106,10 +103,9 @@ for _task in _goal_and_endeffector_change_envs:
 "num_basis": 5,
 "duration": 6.25,
 "post_traj_time": 0,
-"width": 0.025,
 "zero_start": True,
 "policy_type": "metaworld",
 }
 }
 )
-ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id)
+ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
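The registration loops above derive every MP environment id from the raw MetaWorld task name; the naming rule itself is plain string handling and can be checked in isolation:

```python
# Mirrors the id construction used in the registration loops above.
_task = "button-press-v2"
task_id_split = _task.split("-")
name = "".join([s.capitalize() for s in task_id_split[:-1]])
_env_id = f'{name}ProMP-{task_id_split[-1]}'
print(_env_id)  # ButtonPressProMP-v2, the id used in the MetaWorld example script above
```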
@@ -8,7 +8,7 @@ These environments are wrapped-versions of their OpenAI-gym counterparts.
 
 |Name| Description|Trajectory Horizon|Action Dimension|Context Dimension
 |---|---|---|---|---|
-|`ContinuousMountainCarDetPMP-v0`| A DetPmP wrapped version of the ContinuousMountainCar-v0 environment. | 100 | 1
-|`ReacherDetPMP-v2`| A DetPmP wrapped version of the Reacher-v2 environment. | 50 | 2
-|`FetchSlideDenseDetPMP-v1`| A DetPmP wrapped version of the FetchSlideDense-v1 environment. | 50 | 4
-|`FetchReachDenseDetPMP-v1`| A DetPmP wrapped version of the FetchReachDense-v1 environment. | 50 | 4
+|`ContinuousMountainCarProMP-v0`| A ProMP wrapped version of the ContinuousMountainCar-v0 environment. | 100 | 1
+|`ReacherProMP-v2`| A ProMP wrapped version of the Reacher-v2 environment. | 50 | 2
+|`FetchSlideDenseProMP-v1`| A ProMP wrapped version of the FetchSlideDense-v1 environment. | 50 | 4
+|`FetchReachDenseProMP-v1`| A ProMP wrapped version of the FetchReachDense-v1 environment. | 50 | 4
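A short usage sketch for one of the renamed OpenAI gym wrappers from the table above; the call pattern is the same as for every other MP environment, and sampling a random parameter vector is only an illustration:

```python
import alr_envs

env = alr_envs.make("ReacherProMP-v2", seed=1)
obs = env.reset()
# A single step() call rolls out the full 50-step trajectory listed in the table above.
obs, reward, done, info = env.step(env.action_space.sample())
```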
@@ -3,7 +3,7 @@ from gym.wrappers import FlattenObservation
 
 from . import classic_control, mujoco, robotics
 
-ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []}
+ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
 
 # Short Continuous Mountain Car
 register(
@@ -16,8 +16,8 @@ register(
 # Open AI
 # Classic Control
 register(
-id='ContinuousMountainCarDetPMP-v1',
-entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+id='ContinuousMountainCarProMP-v1',
+entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
 kwargs={
 "name": "alr_envs:MountainCarContinuous-v1",
 "wrappers": [classic_control.continuous_mountain_car.MPWrapper],
@@ -26,7 +26,6 @@ register(
 "num_basis": 4,
 "duration": 2,
 "post_traj_time": 0,
-"width": 0.02,
 "zero_start": True,
 "policy_type": "motor",
 "policy_kwargs": {
@@ -36,11 +35,11 @@ register(
 }
 }
 )
-ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("ContinuousMountainCarDetPMP-v1")
+ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("ContinuousMountainCarProMP-v1")
 
 register(
-id='ContinuousMountainCarDetPMP-v0',
-entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+id='ContinuousMountainCarProMP-v0',
+entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
 kwargs={
 "name": "gym.envs.classic_control:MountainCarContinuous-v0",
 "wrappers": [classic_control.continuous_mountain_car.MPWrapper],
@@ -49,7 +48,6 @@ register(
 "num_basis": 4,
 "duration": 19.98,
 "post_traj_time": 0,
-"width": 0.02,
 "zero_start": True,
 "policy_type": "motor",
 "policy_kwargs": {
@@ -59,11 +57,11 @@ register(
 }
 }
 )
-ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("ContinuousMountainCarDetPMP-v0")
+ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("ContinuousMountainCarProMP-v0")
 
 register(
-id='ReacherDetPMP-v2',
-entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+id='ReacherProMP-v2',
+entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
 kwargs={
 "name": "gym.envs.mujoco:Reacher-v2",
 "wrappers": [mujoco.reacher_v2.MPWrapper],
@@ -72,7 +70,6 @@ register(
 "num_basis": 6,
 "duration": 1,
 "post_traj_time": 0,
-"width": 0.02,
 "zero_start": True,
 "policy_type": "motor",
 "policy_kwargs": {
@@ -82,11 +79,11 @@ register(
 }
 }
 )
-ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("ReacherDetPMP-v2")
+ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("ReacherProMP-v2")
 
 register(
-id='FetchSlideDenseDetPMP-v1',
-entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+id='FetchSlideDenseProMP-v1',
+entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
 kwargs={
 "name": "gym.envs.robotics:FetchSlideDense-v1",
 "wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
@@ -95,17 +92,16 @@ register(
 "num_basis": 5,
 "duration": 2,
 "post_traj_time": 0,
-"width": 0.02,
 "zero_start": True,
 "policy_type": "position"
 }
 }
 )
-ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("FetchSlideDenseDetPMP-v1")
+ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("FetchSlideDenseProMP-v1")
 
 register(
-id='FetchSlideDetPMP-v1',
-entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+id='FetchSlideProMP-v1',
+entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
 kwargs={
 "name": "gym.envs.robotics:FetchSlide-v1",
 "wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
@@ -114,17 +110,16 @@ register(
 "num_basis": 5,
 "duration": 2,
 "post_traj_time": 0,
-"width": 0.02,
 "zero_start": True,
 "policy_type": "position"
 }
 }
 )
-ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("FetchSlideDetPMP-v1")
+ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("FetchSlideProMP-v1")
 
 register(
-id='FetchReachDenseDetPMP-v1',
-entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+id='FetchReachDenseProMP-v1',
+entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
 kwargs={
 "name": "gym.envs.robotics:FetchReachDense-v1",
 "wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
@@ -133,17 +128,16 @@ register(
 "num_basis": 5,
 "duration": 2,
 "post_traj_time": 0,
-"width": 0.02,
 "zero_start": True,
 "policy_type": "position"
 }
 }
 )
-ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("FetchReachDenseDetPMP-v1")
+ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("FetchReachDenseProMP-v1")
 
 register(
-id='FetchReachDetPMP-v1',
-entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+id='FetchReachProMP-v1',
+entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
 kwargs={
 "name": "gym.envs.robotics:FetchReach-v1",
 "wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
@@ -152,10 +146,9 @@ register(
 "num_basis": 5,
 "duration": 2,
 "post_traj_time": 0,
-"width": 0.02,
 "zero_start": True,
 "policy_type": "position"
 }
 }
 )
-ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("FetchReachDetPMP-v1")
+ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("FetchReachProMP-v1")
@@ -1,3 +1,4 @@
+import warnings
 from typing import Iterable, Type, Union
 
 import gym
@@ -5,7 +6,6 @@ import numpy as np
 from gym.envs.registration import EnvSpec
 
 from mp_env_api import MPEnvWrapper
-from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper
 from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper
 from mp_env_api.mp_wrappers.promp_wrapper import ProMPWrapper
 
@@ -48,6 +48,11 @@ def make(env_id: str, seed, **kwargs):
 Returns: Gym environment
 
 """
+if any([det_pmp in env_id for det_pmp in ["DetPMP", "detpmp"]]):
+warnings.warn("DetPMP is deprecated and converted to ProMP")
+env_id = env_id.replace("DetPMP", "ProMP")
+env_id = env_id.replace("detpmp", "promp")
+
 try:
 # Add seed to kwargs in case it is a predefined gym+dmc hybrid environment.
 if env_id.startswith("dmc"):
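The new branch in `make()` above keeps old ids working: any `DetPMP`/`detpmp` id is rewritten to its `ProMP`/`promp` counterpart and a deprecation warning is emitted. A small sketch of that behaviour, using the HoleReacher id from the README:

```python
import warnings

import alr_envs

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Deprecated id: rewritten to 'HoleReacherProMP-v0' by the branch shown above.
    env = alr_envs.make("HoleReacherDetPMP-v0", seed=1)

print([str(w.message) for w in caught])  # ["DetPMP is deprecated and converted to ProMP"]
```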
@@ -153,26 +158,6 @@ def make_promp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwar
 return ProMPWrapper(_env, **mp_kwargs)
 
 
-def make_detpmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs):
-"""
-This can also be used standalone for manually building a custom Det ProMP environment.
-Args:
-env_id: base_env_name,
-wrappers: list of wrappers (at least an MPEnvWrapper),
-mp_kwargs: dict of at least {num_dof: int, num_basis: int, width: int}
-
-Returns: Det ProMP wrapped gym env
-
-"""
-_verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
-
-_env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
-
-_verify_dof(_env, mp_kwargs.get("num_dof"))
-
-return DetPMPWrapper(_env, **mp_kwargs)
-
-
 def make_dmp_env_helper(**kwargs):
 """
 Helper function for registering a DMP gym environments.
@@ -212,26 +197,6 @@ def make_promp_env_helper(**kwargs):
 mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs)
 
 
-def make_detpmp_env_helper(**kwargs):
-"""
-Helper function for registering ProMP gym environments.
-This can also be used standalone for manually building a custom ProMP environment.
-Args:
-**kwargs: expects at least the following:
-{
-"name": base_env_name,
-"wrappers": list of wrappers (at least an MPEnvWrapper),
-"mp_kwargs": dict of at least {num_dof: int, num_basis: int, width: int}
-}
-
-Returns: DMP wrapped gym env
-
-"""
-seed = kwargs.pop("seed", None)
-return make_detpmp_env(env_id=kwargs.pop("name"), wrappers=kwargs.pop("wrappers"), seed=seed,
-mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs)
-
-
 def _verify_time_limit(mp_time_limit: Union[None, float], env_time_limit: Union[None, float]):
 """
 When using DMC check if a manually specified time limit matches the trajectory duration the MP receives.
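With `make_detpmp_env_helper` removed above, all registrations point at the remaining ProMP helper. A hedged sketch of calling it directly with the kwargs layout its docstring expects; the wrapper import path is an assumption mirroring the ReacherProMP-v2 registration, and the `mp_kwargs` shown are only a partial set:

```python
from alr_envs.open_ai import mujoco  # assumed import path, mirroring the registration file
from alr_envs.utils.make_env_helpers import make_promp_env_helper

# kwargs layout as documented for the helper: name, wrappers and mp_kwargs, plus an optional seed.
env = make_promp_env_helper(
    name="gym.envs.mujoco:Reacher-v2",
    wrappers=[mujoco.reacher_v2.MPWrapper],
    seed=1,
    mp_kwargs={"num_dof": 2, "num_basis": 6, "duration": 1, "policy_type": "motor",
               "weights_scale": 1, "zero_start": True},
)
```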
@@ -98,8 +98,8 @@ class TestMPEnvironments(unittest.TestCase):
 with self.subTest(msg=env_id):
 self._run_env(env_id)
 
-with self.subTest(msg="DetPMP"):
-for env_id in alr_envs.ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS['DetPMP']:
+with self.subTest(msg="ProMP"):
+for env_id in alr_envs.ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS['ProMP']:
 with self.subTest(msg=env_id):
 self._run_env(env_id)
 
@@ -110,8 +110,8 @@ class TestMPEnvironments(unittest.TestCase):
 with self.subTest(msg=env_id):
 self._run_env(env_id)
 
-with self.subTest(msg="DetPMP"):
-for env_id in alr_envs.ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS['DetPMP']:
+with self.subTest(msg="ProMP"):
+for env_id in alr_envs.ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS['ProMP']:
 with self.subTest(msg=env_id):
 self._run_env(env_id)
 
@@ -122,8 +122,8 @@ class TestMPEnvironments(unittest.TestCase):
 with self.subTest(msg=env_id):
 self._run_env(env_id)
 
-with self.subTest(msg="DetPMP"):
-for env_id in alr_envs.ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS['DetPMP']:
+with self.subTest(msg="ProMP"):
+for env_id in alr_envs.ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS['ProMP']:
 with self.subTest(msg=env_id):
 self._run_env(env_id)
 
@@ -134,8 +134,8 @@ class TestMPEnvironments(unittest.TestCase):
 with self.subTest(msg=env_id):
 self._run_env(env_id)
 
-with self.subTest(msg="DetPMP"):
-for env_id in alr_envs.ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS['DetPMP']:
+with self.subTest(msg="ProMP"):
+for env_id in alr_envs.ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS['ProMP']:
 with self.subTest(msg=env_id):
 self._run_env(env_id)
 
@@ -143,29 +143,29 @@ class TestMPEnvironments(unittest.TestCase):
 """Tests that identical seeds produce identical trajectories for ALR MP Envs."""
 with self.subTest(msg="DMP"):
 self._run_env_determinism(alr_envs.ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"])
-with self.subTest(msg="DetPMP"):
-self._run_env_determinism(alr_envs.ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"])
+with self.subTest(msg="ProMP"):
+self._run_env_determinism(alr_envs.ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"])
 
 def test_openai_environment_determinism(self):
 """Tests that identical seeds produce identical trajectories for OpenAI gym MP Envs."""
 with self.subTest(msg="DMP"):
 self._run_env_determinism(alr_envs.ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"])
-with self.subTest(msg="DetPMP"):
-self._run_env_determinism(alr_envs.ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"])
+with self.subTest(msg="ProMP"):
+self._run_env_determinism(alr_envs.ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"])
 
 def test_dmc_environment_determinism(self):
 """Tests that identical seeds produce identical trajectories for DMC MP Envs."""
 with self.subTest(msg="DMP"):
 self._run_env_determinism(alr_envs.ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"])
-with self.subTest(msg="DetPMP"):
-self._run_env_determinism(alr_envs.ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"])
+with self.subTest(msg="ProMP"):
+self._run_env_determinism(alr_envs.ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"])
 
 def test_metaworld_environment_determinism(self):
 """Tests that identical seeds produce identical trajectories for Metaworld MP Envs."""
 with self.subTest(msg="DMP"):
 self._run_env_determinism(alr_envs.ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"])
-with self.subTest(msg="DetPMP"):
-self._run_env_determinism(alr_envs.ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"])
+with self.subTest(msg="ProMP"):
+self._run_env_determinism(alr_envs.ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"])
 
 
 if __name__ == '__main__':
@@ -81,13 +81,13 @@ class TestStepMetaWorlEnvironments(unittest.TestCase):
 def _verify_done(self, done):
 self.assertIsInstance(done, bool, f"Returned {done} as done flag, expected bool.")
 
-def test_dmc_functionality(self):
+def test_metaworld_functionality(self):
 """Tests that environments runs without errors using random actions."""
 for env_id in ALL_ENVS:
 with self.subTest(msg=env_id):
 self._run_env(env_id)
 
-def test_dmc_determinism(self):
+def test_metaworld_determinism(self):
 """Tests that identical seeds produce identical trajectories."""
 seed = 0
 # Iterate over two trajectories, which should have the same state and action sequence
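The determinism tests renamed above all follow the same recipe; a hedged, self-contained sketch of that check for a single ProMP environment (the helper `_run_env_determinism` itself is not shown in this diff, and the seedable action space is an assumption):

```python
import numpy as np
import alr_envs


def rollout(env_id, seed):
    env = alr_envs.make(env_id, seed=seed)
    env.action_space.seed(seed)  # assumption: the MP wrapper exposes a seedable gym action space
    env.reset()
    return env.step(env.action_space.sample())


obs_a, rew_a, *_ = rollout("HoleReacherProMP-v0", seed=0)
obs_b, rew_b, *_ = rollout("HoleReacherProMP-v0", seed=0)
assert np.allclose(obs_a, obs_b) and rew_a == rew_b  # identical seeds -> identical trajectories
```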