commit 04d27426ba
@@ -81,7 +81,7 @@ trajectory.
 ```python
 import alr_envs
 
-env = alr_envs.make('HoleReacherDetPMP-v0', seed=1)
+env = alr_envs.make('HoleReacherProMP-v0', seed=1)
 # render() can be called once in the beginning with all necessary arguments. To turn it of again just call render(None).
 env.render()
 
@@ -95,7 +95,7 @@ for i in range(5):
 ```
 
 To show all available environments, we provide some additional convenience. Each value will return a dictionary with two
-keys `DMP` and `DetPMP` that store a list of available environment names.
+keys `DMP` and `ProMP` that store a list of available environment names.
 
 ```python
 import alr_envs
@@ -193,7 +193,7 @@ mp_kwargs = {...}
 kwargs = {...}
 env = alr_envs.make_dmp_env(base_env_id, wrappers=wrappers, seed=1, mp_kwargs=mp_kwargs, **kwargs)
 # OR for a deterministic ProMP (other mp_kwargs are required):
-# env = alr_envs.make_detpmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args)
+# env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args)
 
 rewards = 0
 obs = env.reset()
@@ -1,5 +1,5 @@
 from alr_envs import dmc, meta, open_ai
-from alr_envs.utils.make_env_helpers import make, make_detpmp_env, make_dmp_env, make_promp_env, make_rank
+from alr_envs.utils.make_env_helpers import make, make_dmp_env, make_promp_env, make_rank
 from alr_envs.utils import make_dmc
 
 # Convenience function for all MP environments
@@ -10,7 +10,9 @@ from .mujoco.ball_in_a_cup.biac_pd import ALRBallInACupPDEnv
 from .mujoco.reacher.alr_reacher import ALRReacherEnv
 from .mujoco.reacher.balancing import BalancingEnv
 
-ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []}
+from alr_envs.alr.mujoco.table_tennis.tt_gym import MAX_EPISODE_STEPS
 
+ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
+
 # Classic Control
 ## Simple Reacher
@@ -195,16 +197,20 @@ register(
 )
 
 ## Table Tennis
-from alr_envs.alr.mujoco.table_tennis.tt_gym import MAX_EPISODE_STEPS
 register(id='TableTennis2DCtxt-v0',
-         entry_point='alr_envs.alr.mujoco:TT_Env_Gym',
+         entry_point='alr_envs.alr.mujoco:TTEnvGym',
          max_episode_steps=MAX_EPISODE_STEPS,
-         kwargs={'ctxt_dim':2})
+         kwargs={'ctxt_dim': 2})
+
+register(id='TableTennis2DCtxt-v1',
+         entry_point='alr_envs.alr.mujoco:TTEnvGym',
+         max_episode_steps=MAX_EPISODE_STEPS,
+         kwargs={'ctxt_dim': 2, 'fixed_goal': True})
 
 register(id='TableTennis4DCtxt-v0',
-         entry_point='alr_envs.alr.mujoco:TT_Env_Gym',
+         entry_point='alr_envs.alr.mujoco:TTEnvGym',
          max_episode_steps=MAX_EPISODE_STEPS,
-         kwargs={'ctxt_dim':4})
+         kwargs={'ctxt_dim': 4})
 
 ## BeerPong
 difficulties = ["simple", "intermediate", "hard", "hardest"]
@@ -240,8 +246,12 @@ for _v in _versions:
                 "duration": 2,
                 "alpha_phase": 2,
                 "learn_goal": True,
-                "policy_type": "velocity",
+                "policy_type": "motor",
                 "weights_scale": 50,
+                "policy_kwargs": {
+                    "p_gains": .6,
+                    "d_gains": .075
+                }
             }
         }
     )
@@ -260,33 +270,16 @@ for _v in _versions:
                 "duration": 2,
                 "policy_type": "motor",
                 "weights_scale": 1,
-                "zero_start": True
+                "zero_start": True,
+                "policy_kwargs": {
+                    "p_gains": .6,
+                    "d_gains": .075
+                }
             }
         }
     )
     ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
 
-    _env_id = f'{_name[0]}DetPMP-{_name[1]}'
-    register(
-        id=_env_id,
-        entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
-        # max_episode_steps=1,
-        kwargs={
-            "name": f"alr_envs:{_v}",
-            "wrappers": [classic_control.simple_reacher.MPWrapper],
-            "mp_kwargs": {
-                "num_dof": 2 if "long" not in _v.lower() else 5,
-                "num_basis": 5,
-                "duration": 2,
-                "width": 0.025,
-                "policy_type": "velocity",
-                "weights_scale": 0.2,
-                "zero_start": True
-            }
-        }
-    )
-    ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id)
-
 # Viapoint reacher
 register(
     id='ViaPointReacherDMP-v0',
@@ -318,7 +311,7 @@ register(
             "num_dof": 5,
             "num_basis": 5,
             "duration": 2,
-            "policy_type": "motor",
+            "policy_type": "velocity",
             "weights_scale": 1,
             "zero_start": True
         }
@@ -326,26 +319,6 @@ register(
 )
 ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("ViaPointReacherProMP-v0")
 
-register(
-    id='ViaPointReacherDetPMP-v0',
-    entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
-    # max_episode_steps=1,
-    kwargs={
-        "name": "alr_envs:ViaPointReacher-v0",
-        "wrappers": [classic_control.viapoint_reacher.MPWrapper],
-        "mp_kwargs": {
-            "num_dof": 5,
-            "num_basis": 5,
-            "duration": 2,
-            "width": 0.025,
-            "policy_type": "velocity",
-            "weights_scale": 0.2,
-            "zero_start": True
-        }
-    }
-)
-ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("ViaPointReacherDetPMP-v0")
-
 ## Hole Reacher
 _versions = ["v0", "v1", "v2"]
 for _v in _versions:
@@ -391,71 +364,77 @@ for _v in _versions:
     )
     ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
 
-    _env_id = f'HoleReacherDetPMP-{_v}'
+## Beerpong
+_versions = ["v0", "v1", "v2", "v3"]
+for _v in _versions:
+    _env_id = f'BeerpongProMP-{_v}'
     register(
         id=_env_id,
-        entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+        entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
         kwargs={
-            "name": f"alr_envs:HoleReacher-{_v}",
-            "wrappers": [classic_control.hole_reacher.MPWrapper],
+            "name": f"alr_envs:ALRBeerPong-{_v}",
+            "wrappers": [mujoco.beerpong.MPWrapper],
             "mp_kwargs": {
-                "num_dof": 5,
-                "num_basis": 5,
-                "duration": 2,
-                "width": 0.025,
-                "policy_type": "velocity",
-                "weights_scale": 0.2,
-                "zero_start": True
+                "num_dof": 7,
+                "num_basis": 2,
+                "duration": 1,
+                "post_traj_time": 2,
+                "policy_type": "motor",
+                "weights_scale": 1,
+                "zero_start": True,
+                "zero_goal": False,
+                "policy_kwargs": {
+                    "p_gains": np.array([ 1.5, 5, 2.55, 3, 2., 2, 1.25]),
+                    "d_gains": np.array([0.02333333, 0.1, 0.0625, 0.08, 0.03, 0.03, 0.0125])
+                }
             }
         }
     )
-    ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id)
+    ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
 
-## Beerpong
-register(
-    id='BeerpongDetPMP-v0',
-    entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
-    kwargs={
-        "name": "alr_envs:ALRBeerPong-v0",
-        "wrappers": [mujoco.beerpong.MPWrapper],
-        "mp_kwargs": {
-            "num_dof": 7,
-            "num_basis": 2,
-            "n_zero_bases": 2,
-            "duration": 0.5,
-            "post_traj_time": 2.5,
-            "width": 0.01,
-            "off": 0.01,
-            "policy_type": "motor",
-            "weights_scale": 0.08,
-            "zero_start": True,
-            "zero_goal": False,
-            "policy_kwargs": {
-                "p_gains": np.array([ 1.5, 5, 2.55, 3, 2., 2, 1.25]),
-                "d_gains": np.array([0.02333333, 0.1, 0.0625, 0.08, 0.03, 0.03, 0.0125])
-            }
-        }
-    }
-)
-ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("BeerpongDetPMP-v0")
-
 ## Table Tennis
+ctxt_dim = [2, 4]
+for _v, cd in enumerate(ctxt_dim):
+    _env_id = f'TableTennisProMP-v{_v}'
+    register(
+        id=_env_id,
+        entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
+        kwargs={
+            "name": "alr_envs:TableTennis{}DCtxt-v0".format(cd),
+            "wrappers": [mujoco.table_tennis.MPWrapper],
+            "mp_kwargs": {
+                "num_dof": 7,
+                "num_basis": 2,
+                "duration": 1.25,
+                "post_traj_time": 4.5,
+                "policy_type": "motor",
+                "weights_scale": 1.0,
+                "zero_start": True,
+                "zero_goal": False,
+                "policy_kwargs": {
+                    "p_gains": 0.5*np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0]),
+                    "d_gains": 0.5*np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1])
+                }
+            }
+        }
+    )
+    ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
+
 register(
-    id='TableTennisDetPMP-v0',
-    entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+    id='TableTennisProMP-v2',
+    entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
    kwargs={
-        "name": "alr_envs:TableTennis4DCtxt-v0",
+        "name": "alr_envs:TableTennis2DCtxt-v1",
        "wrappers": [mujoco.table_tennis.MPWrapper],
        "mp_kwargs": {
            "num_dof": 7,
            "num_basis": 2,
-            "n_zero_bases": 2,
-            "duration": 1.25,
-            "post_traj_time": 4.5,
-            "width": 0.01,
-            "off": 0.01,
+            "duration": 1.,
+            "post_traj_time": 2.5,
            "policy_type": "motor",
-            "weights_scale": 1.0,
+            "weights_scale": 1,
+            "off": -0.05,
+            "bandwidth_factor": 3.5,
            "zero_start": True,
            "zero_goal": False,
            "policy_kwargs": {
@@ -465,4 +444,4 @@ register(
         }
     }
 )
-ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("TableTennisDetPMP-v0")
+ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("TableTennisProMP-v2")
@@ -14,8 +14,5 @@
 |`ViaPointReacherDMP-v0`| A DMP provides a trajectory for the `ViaPointReacher-v0` task. | 200 | 25
 |`HoleReacherFixedGoalDMP-v0`| A DMP provides a trajectory for the `HoleReacher-v0` task with a fixed goal attractor. | 200 | 25
 |`HoleReacherDMP-v0`| A DMP provides a trajectory for the `HoleReacher-v0` task. The goal attractor needs to be learned. | 200 | 30
-|`ALRBallInACupSimpleDMP-v0`| A DMP provides a trajectory for the `ALRBallInACupSimple-v0` task where only 3 joints are actuated. | 4000 | 15
-|`ALRBallInACupDMP-v0`| A DMP provides a trajectory for the `ALRBallInACup-v0` task. | 4000 | 35
-|`ALRBallInACupGoalDMP-v0`| A DMP provides a trajectory for the `ALRBallInACupGoal-v0` task. | 4000 | 35 | 3
 
-[//]: |`HoleReacherDetPMP-v0`|
+[//]: |`HoleReacherProMPP-v0`|
@@ -5,7 +5,6 @@ import matplotlib.pyplot as plt
 import numpy as np
 from gym.utils import seeding
 
-from alr_envs.alr.classic_control.utils import check_self_collision
 from alr_envs.alr.classic_control.base_reacher.base_reacher_direct import BaseReacherDirectEnv
 
 
@@ -2,5 +2,5 @@ from .reacher.alr_reacher import ALRReacherEnv
 from .reacher.balancing import BalancingEnv
 from .ball_in_a_cup.ball_in_a_cup import ALRBallInACupEnv
 from .ball_in_a_cup.biac_pd import ALRBallInACupPDEnv
-from .table_tennis.tt_gym import TT_Env_Gym
+from .table_tennis.tt_gym import TTEnvGym
 from .beerpong.beerpong import ALRBeerBongEnv
@@ -1,107 +0,0 @@
-from alr_envs.alr.mujoco.ball_in_a_cup.ball_in_a_cup import ALRBallInACupEnv
-from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper
-from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper
-
-
-def make_contextual_env(rank, seed=0):
-    """
-    Utility function for multiprocessed env.
-
-    :param env_id: (str) the environment ID
-    :param num_env: (int) the number of environments you wish to have in subprocesses
-    :param seed: (int) the initial seed for RNG
-    :param rank: (int) index of the subprocess
-    :returns a function that generates an environment
-    """
-
-    def _init():
-        env = ALRBallInACupEnv(reward_type="contextual_goal")
-
-        env = DetPMPWrapper(env, num_dof=7, num_basis=5, width=0.005, duration=3.5, dt=env.dt, post_traj_time=4.5,
-                            policy_type="motor", weights_scale=0.5, zero_start=True, zero_goal=True)
-
-        env.seed(seed + rank)
-        return env
-
-    return _init
-
-
-def _make_env(rank, seed=0):
-    """
-    Utility function for multiprocessed env.
-
-    :param env_id: (str) the environment ID
-    :param num_env: (int) the number of environments you wish to have in subprocesses
-    :param seed: (int) the initial seed for RNG
-    :param rank: (int) index of the subprocess
-    :returns a function that generates an environment
-    """
-
-    def _init():
-        env = ALRBallInACupEnv(reward_type="simple")
-
-        env = DetPMPWrapper(env, num_dof=7, num_basis=5, width=0.005, duration=3.5, dt=env.dt, post_traj_time=4.5,
-                            policy_type="motor", weights_scale=0.2, zero_start=True, zero_goal=True)
-
-        env.seed(seed + rank)
-        return env
-
-    return _init
-
-
-def make_simple_env(rank, seed=0):
-    """
-    Utility function for multiprocessed env.
-
-    :param env_id: (str) the environment ID
-    :param num_env: (int) the number of environments you wish to have in subprocesses
-    :param seed: (int) the initial seed for RNG
-    :param rank: (int) index of the subprocess
-    :returns a function that generates an environment
-    """
-
-    def _init():
-        env = ALRBallInACupEnv(reward_type="simple")
-
-        env = DetPMPWrapper(env, num_dof=3, num_basis=5, width=0.005, duration=3.5, dt=env.dt, post_traj_time=4.5,
-                            policy_type="motor", weights_scale=0.25, zero_start=True, zero_goal=True, off=-0.1)
-
-        env.seed(seed + rank)
-        return env
-
-    return _init
-
-
-def make_simple_dmp_env(rank, seed=0):
-    """
-    Utility function for multiprocessed env.
-
-    :param env_id: (str) the environment ID
-    :param num_env: (int) the number of environments you wish to have in subprocesses
-    :param seed: (int) the initial seed for RNG
-    :param rank: (int) index of the subprocess
-    :returns a function that generates an environment
-    """
-
-    def _init():
-        _env = ALRBallInACupEnv(reward_type="simple")
-
-        _env = DmpWrapper(_env,
-                          num_dof=3,
-                          num_basis=5,
-                          duration=3.5,
-                          post_traj_time=4.5,
-                          bandwidth_factor=2.5,
-                          dt=_env.dt,
-                          learn_goal=False,
-                          alpha_phase=3,
-                          start_pos=_env.start_pos[1::2],
-                          final_pos=_env.start_pos[1::2],
-                          policy_type="motor",
-                          weights_scale=100,
-                          )
-
-        _env.seed(seed + rank)
-        return _env
-
-    return _init
@@ -27,10 +27,10 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
         self.ball_site_id = 0
         self.ball_id = 11
 
-        self._release_step = 100 # time step of ball release
+        self._release_step = 175 # time step of ball release
 
-        self.sim_time = 4 # seconds
-        self.ep_length = 600 # based on 5 seconds with dt = 0.005 int(self.sim_time / self.dt)
+        self.sim_time = 3 # seconds
+        self.ep_length = 600 # based on 3 seconds with dt = 0.005 int(self.sim_time / self.dt)
         self.cup_table_id = 10
 
         if noisy:
@@ -127,24 +127,28 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
             self._steps += 1
         else:
             reward = -30
+            reward_infos = dict()
             success = False
             is_collided = False
             done = True
             ball_pos = np.zeros(3)
             ball_vel = np.zeros(3)
 
-        return ob, reward, done, dict(reward_dist=reward_dist,
+        infos = dict(reward_dist=reward_dist,
                      reward_ctrl=reward_ctrl,
                      reward=reward,
                      velocity=angular_vel,
                      # traj=self._q_pos,
                      action=a,
                      q_pos=self.sim.data.qpos[0:7].ravel().copy(),
                      q_vel=self.sim.data.qvel[0:7].ravel().copy(),
                      ball_pos=ball_pos,
                      ball_vel=ball_vel,
-                     is_success=success,
+                     success=success,
                      is_collided=is_collided, sim_crash=crash)
+        infos.update(reward_infos)
+
+        return ob, reward, done, infos
 
     def check_traj_in_joint_limits(self):
         return any(self.current_pos > self.j_max) or any(self.current_pos < self.j_min)
@@ -171,7 +175,7 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
 
 
 if __name__ == "__main__":
-    env = ALRBeerBongEnv(reward_type="no_context", difficulty='hardest')
+    env = ALRBeerBongEnv(reward_type="staged", difficulty='hardest')
 
     # env.configure(ctxt)
     env.reset()
@@ -71,6 +71,7 @@ class BeerPongReward:
 
         goal_pos = env.sim.data.site_xpos[self.goal_id]
         ball_pos = env.sim.data.body_xpos[self.ball_id]
+        ball_vel = env.sim.data.body_xvelp[self.ball_id]
         goal_final_pos = env.sim.data.site_xpos[self.goal_final_id]
         self.dists.append(np.linalg.norm(goal_pos - ball_pos))
         self.dists_final.append(np.linalg.norm(goal_final_pos - ball_pos))
@@ -131,6 +132,7 @@ class BeerPongReward:
         infos["success"] = success
         infos["is_collided"] = self._is_collided
         infos["ball_pos"] = ball_pos.copy()
+        infos["ball_vel"] = ball_vel.copy()
         infos["action_cost"] = 5e-4 * action_cost
 
         return reward, infos
@@ -81,32 +81,36 @@ class BeerPongReward:
         action_cost = np.sum(np.square(action))
         self.action_costs.append(action_cost)
 
+        if not self.ball_table_contact:
+            self.ball_table_contact = self._check_collision_single_objects(env.sim, self.ball_collision_id,
+                                                                           self.table_collision_id)
+
         self._is_collided = self._check_collision_with_itself(env.sim, self.robot_collision_ids)
         if env._steps == env.ep_length - 1 or self._is_collided:
 
             min_dist = np.min(self.dists)
-            ball_table_bounce = self._check_collision_single_objects(env.sim, self.ball_collision_id,
-                                                                     self.table_collision_id)
-            ball_cup_table_cont = self._check_collision_with_set_of_objects(env.sim, self.ball_collision_id,
-                                                                            self.cup_collision_ids)
-            ball_wall_cont = self._check_collision_single_objects(env.sim, self.ball_collision_id,
-                                                                  self.wall_collision_id)
+            final_dist = self.dists_final[-1]
             ball_in_cup = self._check_collision_single_objects(env.sim, self.ball_collision_id,
                                                                self.cup_table_collision_id)
-            if not ball_in_cup:
-                cost_offset = 2
-                if not ball_cup_table_cont and not ball_table_bounce and not ball_wall_cont:
-                    cost_offset += 2
-                cost = cost_offset + min_dist ** 2 + 0.5 * self.dists_final[-1] ** 2 + 1e-7 * action_cost
-            else:
-                cost = self.dists_final[-1] ** 2 + 1.5 * action_cost * 1e-7
 
-            reward = - 1 * cost - self.collision_penalty * int(self._is_collided)
+            # encourage bounce before falling into cup
+            if not ball_in_cup:
+                if not self.ball_table_contact:
+                    reward = 0.2 * (1 - np.tanh(min_dist ** 2)) + 0.1 * (1 - np.tanh(final_dist ** 2))
+                else:
+                    reward = (1 - np.tanh(min_dist ** 2)) + 0.5 * (1 - np.tanh(final_dist ** 2))
+            else:
+                if not self.ball_table_contact:
+                    reward = 2 * (1 - np.tanh(final_dist ** 2)) + 1 * (1 - np.tanh(min_dist ** 2)) + 1
+                else:
+                    reward = 2 * (1 - np.tanh(final_dist ** 2)) + 1 * (1 - np.tanh(min_dist ** 2)) + 3
+
+            # reward = - 1 * cost - self.collision_penalty * int(self._is_collided)
             success = ball_in_cup
             crash = self._is_collided
         else:
-            reward = - 1e-7 * action_cost
-            cost = 0
+            reward = - 1e-2 * action_cost
             success = False
             crash = False
 
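
The hunk above replaces the cost-based formulation with tanh-shaped staged rewards. Pulled out of diff form for readability, the new branch structure amounts to the following standalone sketch (same constants and branches as the hunk; the free function is illustrative only, the real code lives inside `BeerPongReward`):

```python
import numpy as np

def staged_beerpong_reward(min_dist, final_dist, ball_in_cup, ball_table_contact):
    """Tanh-shaped staged reward mirroring the branches introduced above."""
    if not ball_in_cup:
        if not ball_table_contact:
            # neither table bounce nor cup: weak shaping signal only
            return 0.2 * (1 - np.tanh(min_dist ** 2)) + 0.1 * (1 - np.tanh(final_dist ** 2))
        # bounced off the table but missed the cup: stronger shaping
        return (1 - np.tanh(min_dist ** 2)) + 0.5 * (1 - np.tanh(final_dist ** 2))
    # ball in cup: the flat bonus is larger (+3 vs. +1) when the ball bounced
    # off the table first, i.e. "encourage bounce before falling into cup"
    bonus = 3 if ball_table_contact else 1
    return 2 * (1 - np.tanh(final_dist ** 2)) + (1 - np.tanh(min_dist ** 2)) + bonus
```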
@@ -115,26 +119,11 @@ class BeerPongReward:
         infos["is_collided"] = self._is_collided
         infos["ball_pos"] = ball_pos.copy()
         infos["ball_vel"] = ball_vel.copy()
-        infos["action_cost"] = 5e-4 * action_cost
-        infos["task_cost"] = cost
+        infos["action_cost"] = action_cost
+        infos["task_reward"] = reward
 
         return reward, infos
 
-    def get_cost_offset(self):
-        if self.ball_ground_contact:
-            return 200
-
-        if not self.ball_table_contact:
-            return 100
-
-        if not self.ball_in_cup:
-            return 50
-
-        if self.ball_in_cup and self.ball_cup_contact and not self.noisy_bp:
-            return 10
-
-        return 0
-
     def _check_collision_single_objects(self, sim, id_1, id_2):
         for coni in range(0, sim.data.ncon):
             con = sim.data.contact[coni]
@@ -6,8 +6,6 @@ from gym.envs.mujoco import MujocoEnv
 
 class ALRBeerpongEnv(MujocoEnv, utils.EzPickle):
     def __init__(self, n_substeps=4, apply_gravity_comp=True, reward_function=None):
-        utils.EzPickle.__init__(**locals())
-
         self._steps = 0
 
         self.xml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
@@ -28,15 +26,13 @@ class ALRBeerpongEnv(MujocoEnv, utils.EzPickle):
 
         self.context = None
 
-        MujocoEnv.__init__(self, model_path=self.xml_path, frame_skip=n_substeps)
-
         # alr_mujoco_env.AlrMujocoEnv.__init__(self,
         #                                      self.xml_path,
         #                                      apply_gravity_comp=apply_gravity_comp,
         #                                      n_substeps=n_substeps)
 
         self.sim_time = 8 # seconds
-        self.sim_steps = int(self.sim_time / self.dt)
+        # self.sim_steps = int(self.sim_time / self.dt)
         if reward_function is None:
             from alr_envs.alr.mujoco.beerpong.beerpong_reward_simple import BeerpongReward
             reward_function = BeerpongReward
@@ -46,6 +42,9 @@ class ALRBeerpongEnv(MujocoEnv, utils.EzPickle):
         self.cup_table_id = self.sim.model._body_name2id["cup_table"]
         # self.bounce_table_id = self.sim.model._body_name2id["bounce_table"]
 
+        MujocoEnv.__init__(self, model_path=self.xml_path, frame_skip=n_substeps)
+        utils.EzPickle.__init__(self)
+
     @property
     def current_pos(self):
         return self.sim.data.qpos[0:7].copy()
@@ -90,7 +89,7 @@ class ALRBeerpongEnv(MujocoEnv, utils.EzPickle):
         reward_ctrl = - np.square(a).sum()
         action_cost = np.sum(np.square(a))
 
-        crash = self.do_simulation(a)
+        crash = self.do_simulation(a, self.frame_skip)
         joint_cons_viol = self.check_traj_in_joint_limits()
 
         self._q_pos.append(self.sim.data.qpos[0:7].ravel().copy())
@@ -1,72 +0,0 @@
-from alr_envs.utils.mps.detpmp_wrapper import DetPMPWrapper
-from alr_envs.alr.mujoco.beerpong.beerpong import ALRBeerpongEnv
-from alr_envs.alr.mujoco.beerpong.beerpong_simple import ALRBeerpongEnv as ALRBeerpongEnvSimple
-
-
-def make_contextual_env(rank, seed=0):
-    """
-    Utility function for multiprocessed env.
-
-    :param env_id: (str) the environment ID
-    :param num_env: (int) the number of environments you wish to have in subprocesses
-    :param seed: (int) the initial seed for RNG
-    :param rank: (int) index of the subprocess
-    :returns a function that generates an environment
-    """
-
-    def _init():
-        env = ALRBeerpongEnv()
-
-        env = DetPMPWrapper(env, num_dof=7, num_basis=5, width=0.005, duration=3.5, dt=env.dt, post_traj_time=4.5,
-                            policy_type="motor", weights_scale=0.5, zero_start=True, zero_goal=True)
-
-        env.seed(seed + rank)
-        return env
-
-    return _init
-
-
-def _make_env(rank, seed=0):
-    """
-    Utility function for multiprocessed env.
-
-    :param env_id: (str) the environment ID
-    :param num_env: (int) the number of environments you wish to have in subprocesses
-    :param seed: (int) the initial seed for RNG
-    :param rank: (int) index of the subprocess
-    :returns a function that generates an environment
-    """
-
-    def _init():
-        env = ALRBeerpongEnvSimple()
-
-        env = DetPMPWrapper(env, num_dof=7, num_basis=5, width=0.005, duration=3.5, dt=env.dt, post_traj_time=4.5,
-                            policy_type="motor", weights_scale=0.25, zero_start=True, zero_goal=True)
-
-        env.seed(seed + rank)
-        return env
-
-    return _init
-
-
-def make_simple_env(rank, seed=0):
-    """
-    Utility function for multiprocessed env.
-
-    :param env_id: (str) the environment ID
-    :param num_env: (int) the number of environments you wish to have in subprocesses
-    :param seed: (int) the initial seed for RNG
-    :param rank: (int) index of the subprocess
-    :returns a function that generates an environment
-    """
-
-    def _init():
-        env = ALRBeerpongEnvSimple()
-
-        env = DetPMPWrapper(env, num_dof=3, num_basis=5, width=0.005, duration=3.5, dt=env.dt, post_traj_time=4.5,
-                            policy_type="motor", weights_scale=0.5, zero_start=True, zero_goal=True)
-
-        env.seed(seed + rank)
-        return env
-
-    return _init
@@ -10,7 +10,7 @@ from alr_envs.alr.mujoco.table_tennis.tt_reward import TT_Reward
 
 #TODO: Check for simulation stability. Make sure the code runs even for sim crash
 
-MAX_EPISODE_STEPS = 1375
+MAX_EPISODE_STEPS = 1750
 BALL_NAME_CONTACT = "target_ball_contact"
 BALL_NAME = "target_ball"
 TABLE_NAME = 'table_tennis_table'
@@ -22,24 +22,30 @@ RACKET_NAME = 'bat'
 CONTEXT_RANGE_BOUNDS_2DIM = np.array([[-1.2, -0.6], [-0.2, 0.0]])
 CONTEXT_RANGE_BOUNDS_4DIM = np.array([[-1.35, -0.75, -1.25, -0.75], [-0.1, 0.75, -0.1, 0.75]])
 
-class TT_Env_Gym(MujocoEnv, utils.EzPickle):
-
-    def __init__(self, ctxt_dim=2):
+class TTEnvGym(MujocoEnv, utils.EzPickle):
+
+    def __init__(self, ctxt_dim=2, fixed_goal=False):
         model_path = os.path.join(os.path.dirname(__file__), "xml", 'table_tennis_env.xml')
 
         self.ctxt_dim = ctxt_dim
+        self.fixed_goal = fixed_goal
         if ctxt_dim == 2:
             self.context_range_bounds = CONTEXT_RANGE_BOUNDS_2DIM
-            self.goal = np.zeros(3) # 2 x,y + 1z
+            if self.fixed_goal:
+                self.goal = np.array([-1, -0.1, 0])
+            else:
+                self.goal = np.zeros(3) # 2 x,y + 1z
         elif ctxt_dim == 4:
             self.context_range_bounds = CONTEXT_RANGE_BOUNDS_4DIM
             self.goal = np.zeros(3)
         else:
             raise ValueError("either 2 or 4 dimensional Contexts available")
 
-        action_space_low = np.array([-2.6, -2.0, -2.8, -0.9, -4.8, -1.6, -2.2])
-        action_space_high = np.array([2.6, 2.0, 2.8, 3.1, 1.3, 1.6, 2.2])
-        self.action_space = spaces.Box(low=action_space_low, high=action_space_high, dtype='float64')
+        # has no effect as it is overwritten in init of super
+        # action_space_low = np.array([-2.6, -2.0, -2.8, -0.9, -4.8, -1.6, -2.2])
+        # action_space_high = np.array([2.6, 2.0, 2.8, 3.1, 1.3, 1.6, 2.2])
+        # self.action_space = spaces.Box(low=action_space_low, high=action_space_high, dtype='float64')
 
         self.time_steps = 0
         self.init_qpos_tt = np.array([0, 0, 0, 1.5, 0, 0, 1.5, 0, 0, 0])
@@ -47,10 +53,10 @@ class TT_Env_Gym(MujocoEnv, utils.EzPickle):
 
         self.reward_func = TT_Reward(self.ctxt_dim)
         self.ball_landing_pos = None
-        self.hited_ball = False
+        self.hit_ball = False
         self.ball_contact_after_hit = False
         self._ids_set = False
-        super(TT_Env_Gym, self).__init__(model_path=model_path, frame_skip=1)
+        super(TTEnvGym, self).__init__(model_path=model_path, frame_skip=1)
         self.ball_id = self.sim.model._body_name2id[BALL_NAME] # find the proper -> not protected func.
         self.ball_contact_id = self.sim.model._geom_name2id[BALL_NAME_CONTACT]
         self.table_contact_id = self.sim.model._geom_name2id[TABLE_NAME]
@@ -77,15 +83,18 @@ class TT_Env_Gym(MujocoEnv, utils.EzPickle):
         return obs
 
     def sample_context(self):
-        return np.random.uniform(self.context_range_bounds[0], self.context_range_bounds[1], size=self.ctxt_dim)
+        return self.np_random.uniform(self.context_range_bounds[0], self.context_range_bounds[1], size=self.ctxt_dim)
 
     def reset_model(self):
         self.set_state(self.init_qpos_tt, self.init_qvel_tt) # reset to initial sim state
         self.time_steps = 0
         self.ball_landing_pos = None
-        self.hited_ball = False
+        self.hit_ball = False
         self.ball_contact_after_hit = False
-        self.goal = self.sample_context()[:2]
+        if self.fixed_goal:
+            self.goal = self.goal[:2]
+        else:
+            self.goal = self.sample_context()[:2]
         if self.ctxt_dim == 2:
             initial_ball_state = ball_init(random=False) # fixed velocity, fixed position
         elif self.ctxt_dim == 4:
@@ -122,12 +131,12 @@ class TT_Env_Gym(MujocoEnv, utils.EzPickle):
         if not self._ids_set:
             self._set_ids()
         done = False
-        episode_end = False if self.time_steps+1<MAX_EPISODE_STEPS else True
-        if not self.hited_ball:
-            self.hited_ball = self._contact_checker(self.ball_contact_id, self.paddle_contact_id_1) # check for one side
-            if not self.hited_ball:
-                self.hited_ball = self._contact_checker(self.ball_contact_id, self.paddle_contact_id_2) # check for other side
-        if self.hited_ball:
+        episode_end = False if self.time_steps + 1 < MAX_EPISODE_STEPS else True
+        if not self.hit_ball:
+            self.hit_ball = self._contact_checker(self.ball_contact_id, self.paddle_contact_id_1) # check for one side
+            if not self.hit_ball:
+                self.hit_ball = self._contact_checker(self.ball_contact_id, self.paddle_contact_id_2) # check for other side
+        if self.hit_ball:
             if not self.ball_contact_after_hit:
                 if self._contact_checker(self.ball_contact_id, self.floor_contact_id): # first check contact with floor
                     self.ball_contact_after_hit = True
@@ -140,7 +149,7 @@ class TT_Env_Gym(MujocoEnv, utils.EzPickle):
         if self.ball_landing_pos is not None:
             done = True
             episode_end =True
-        reward = self.reward_func.get_reward(episode_end, c_ball_pos, racket_pos, self.hited_ball, self.ball_landing_pos)
+        reward = self.reward_func.get_reward(episode_end, c_ball_pos, racket_pos, self.hit_ball, self.ball_landing_pos)
         self.time_steps += 1
         # gravity compensation on joints:
         #action += self.sim.data.qfrc_bias[:7].copy()
@@ -151,7 +160,10 @@ class TT_Env_Gym(MujocoEnv, utils.EzPickle):
             done = True
             reward = -25
         ob = self._get_obs()
-        return ob, reward, done, {"hit_ball":self.hited_ball}# might add some information here ....
+        info = {"hit_ball": self.hit_ball,
+                "q_pos": np.copy(self.sim.data.qpos[:7]),
+                "ball_pos": np.copy(self.sim.data.qpos[7:])}
+        return ob, reward, done, info # might add some information here ....
 
     def set_context(self, context):
         old_state = self.sim.get_state()
@@ -19,7 +19,7 @@ class TT_Reward:
         # # seems to work for episodic case
         min_r_b_dist = np.min(np.linalg.norm(np.array(self.c_ball_traj) - np.array(self.c_racket_traj), axis=1))
         if not hited_ball:
-            return 0.2 * (1- np.tanh(min_r_b_dist**2))
+            return 0.2 * (1 - np.tanh(min_r_b_dist**2))
         else:
             if ball_landing_pos is None:
                 min_b_des_b_dist = np.min(np.linalg.norm(np.array(self.c_ball_traj)[:,:2] - self.c_goal[:2], axis=1))
@@ -11,9 +11,9 @@ environments in order to use our Motion Primitive gym interface with them.
 
 |Name| Description|Trajectory Horizon|Action Dimension|Context Dimension
 |---|---|---|---|---|
-|`dmc_ball_in_cup-catch_detpmp-v0`| A DetPmP wrapped version of the "catch" task for the "ball_in_cup" environment. | 1000 | 10 | 2
+|`dmc_ball_in_cup-catch_promp-v0`| A ProMP wrapped version of the "catch" task for the "ball_in_cup" environment. | 1000 | 10 | 2
 |`dmc_ball_in_cup-catch_dmp-v0`| A DMP wrapped version of the "catch" task for the "ball_in_cup" environment. | 1000| 10 | 2
-|`dmc_reacher-easy_detpmp-v0`| A DetPmP wrapped version of the "easy" task for the "reacher" environment. | 1000 | 10 | 4
+|`dmc_reacher-easy_promp-v0`| A ProMP wrapped version of the "easy" task for the "reacher" environment. | 1000 | 10 | 4
 |`dmc_reacher-easy_dmp-v0`| A DMP wrapped version of the "easy" task for the "reacher" environment. | 1000| 10 | 4
-|`dmc_reacher-hard_detpmp-v0`| A DetPmP wrapped version of the "hard" task for the "reacher" environment.| 1000 | 10 | 4
+|`dmc_reacher-hard_promp-v0`| A ProMP wrapped version of the "hard" task for the "reacher" environment.| 1000 | 10 | 4
 |`dmc_reacher-hard_dmp-v0`| A DMP wrapped version of the "hard" task for the "reacher" environment. | 1000 | 10 | 4
@@ -1,6 +1,6 @@
 from . import manipulation, suite
 
-ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []}
+ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
 
 from gym.envs.registration import register
 
@@ -34,8 +34,8 @@ register(
 ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_ball_in_cup-catch_dmp-v0")
 
 register(
-    id=f'dmc_ball_in_cup-catch_detpmp-v0',
-    entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+    id=f'dmc_ball_in_cup-catch_promp-v0',
+    entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
     kwargs={
         "name": f"ball_in_cup-catch",
         "time_limit": 20,
@@ -45,7 +45,6 @@ register(
             "num_dof": 2,
             "num_basis": 5,
             "duration": 20,
-            "width": 0.025,
             "policy_type": "motor",
             "zero_start": True,
             "policy_kwargs": {
@@ -55,7 +54,7 @@ register(
         }
     }
 )
-ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("dmc_ball_in_cup-catch_detpmp-v0")
+ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_ball_in_cup-catch_promp-v0")
 
 register(
     id=f'dmc_reacher-easy_dmp-v0',
@@ -86,8 +85,8 @@ register(
 ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-easy_dmp-v0")
 
 register(
-    id=f'dmc_reacher-easy_detpmp-v0',
-    entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+    id=f'dmc_reacher-easy_promp-v0',
+    entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
     kwargs={
         "name": f"reacher-easy",
         "time_limit": 20,
@@ -97,7 +96,6 @@ register(
             "num_dof": 2,
             "num_basis": 5,
             "duration": 20,
-            "width": 0.025,
             "policy_type": "motor",
             "weights_scale": 0.2,
             "zero_start": True,
@@ -108,7 +106,7 @@ register(
         }
     }
 )
-ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("dmc_reacher-easy_detpmp-v0")
+ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_reacher-easy_promp-v0")
 
 register(
     id=f'dmc_reacher-hard_dmp-v0',
@@ -139,8 +137,8 @@ register(
 ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-hard_dmp-v0")
 
 register(
-    id=f'dmc_reacher-hard_detpmp-v0',
-    entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+    id=f'dmc_reacher-hard_promp-v0',
+    entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
     kwargs={
         "name": f"reacher-hard",
         "time_limit": 20,
@@ -150,7 +148,6 @@ register(
             "num_dof": 2,
             "num_basis": 5,
             "duration": 20,
-            "width": 0.025,
             "policy_type": "motor",
             "weights_scale": 0.2,
             "zero_start": True,
@@ -161,7 +158,7 @@ register(
         }
     }
 )
-ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("dmc_reacher-hard_detpmp-v0")
+ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_reacher-hard_promp-v0")
 
 _dmc_cartpole_tasks = ["balance", "balance_sparse", "swingup", "swingup_sparse"]
 
@@ -196,10 +193,10 @@ for _task in _dmc_cartpole_tasks:
     )
     ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
 
-    _env_id = f'dmc_cartpole-{_task}_detpmp-v0'
+    _env_id = f'dmc_cartpole-{_task}_promp-v0'
     register(
         id=_env_id,
-        entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+        entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
         kwargs={
             "name": f"cartpole-{_task}",
             # "time_limit": 1,
@@ -210,7 +207,6 @@ for _task in _dmc_cartpole_tasks:
                 "num_dof": 1,
                 "num_basis": 5,
                 "duration": 10,
-                "width": 0.025,
                 "policy_type": "motor",
                 "weights_scale": 0.2,
                 "zero_start": True,
@@ -221,7 +217,7 @@ for _task in _dmc_cartpole_tasks:
             }
         }
     )
-    ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id)
+    ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
 
 _env_id = f'dmc_cartpole-two_poles_dmp-v0'
 register(
@@ -253,10 +249,10 @@ register(
 )
 ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
 
-_env_id = f'dmc_cartpole-two_poles_detpmp-v0'
+_env_id = f'dmc_cartpole-two_poles_promp-v0'
 register(
     id=_env_id,
-    entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+    entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
     kwargs={
         "name": f"cartpole-two_poles",
         # "time_limit": 1,
@@ -267,7 +263,6 @@ register(
             "num_dof": 1,
             "num_basis": 5,
             "duration": 10,
-            "width": 0.025,
             "policy_type": "motor",
             "weights_scale": 0.2,
             "zero_start": True,
@@ -278,7 +273,7 @@ register(
         }
     }
 )
-ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id)
+ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
 
 _env_id = f'dmc_cartpole-three_poles_dmp-v0'
 register(
@@ -310,10 +305,10 @@ register(
 )
 ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
 
-_env_id = f'dmc_cartpole-three_poles_detpmp-v0'
+_env_id = f'dmc_cartpole-three_poles_promp-v0'
 register(
     id=_env_id,
-    entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+    entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
     kwargs={
         "name": f"cartpole-three_poles",
         # "time_limit": 1,
@@ -324,7 +319,6 @@ register(
             "num_dof": 1,
             "num_basis": 5,
             "duration": 10,
-            "width": 0.025,
             "policy_type": "motor",
             "weights_scale": 0.2,
             "zero_start": True,
@@ -335,7 +329,7 @@ register(
         }
     }
 )
-ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id)
+ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
 
 # DeepMind Manipulation
 
@@ -364,8 +358,8 @@ register(
 ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_manipulation-reach_site_dmp-v0")
 
 register(
-    id=f'dmc_manipulation-reach_site_detpmp-v0',
-    entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+    id=f'dmc_manipulation-reach_site_promp-v0',
+    entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
     kwargs={
         "name": f"manipulation-reach_site_features",
         # "time_limit": 1,
@@ -375,11 +369,10 @@ register(
             "num_dof": 9,
             "num_basis": 5,
             "duration": 10,
-            "width": 0.025,
             "policy_type": "velocity",
             "weights_scale": 0.2,
             "zero_start": True,
         }
     }
 )
-ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("dmc_manipulation-reach_site_detpmp-v0")
+ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_manipulation-reach_site_promp-v0")
@@ -84,7 +84,7 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
     }
     env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs)
     # OR for a deterministic ProMP (other mp_kwargs are required, see metaworld_examples):
-    # env = alr_envs.make_detpmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args)
+    # env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args)
 
     # This renders the full MP trajectory
     # It is only required to call render() once in the beginning, which renders every consecutive trajectory.
@@ -128,7 +128,7 @@ if __name__ == '__main__':
     example_dmc("manipulation-reach_site_features", seed=10, iterations=250, render=render)
 
     # Gym + DMC hybrid task provided in the MP framework
-    example_dmc("dmc_ball_in_cup-catch_detpmp-v0", seed=10, iterations=1, render=render)
+    example_dmc("dmc_ball_in_cup-catch_promp-v0", seed=10, iterations=1, render=render)
 
     # Custom DMC task
     # Different seed, because the episode is longer for this example and the name+seed combo is already registered above
@@ -76,7 +76,7 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
         "policy_type": "metaworld", # custom controller type for metaworld environments
     }
 
-    env = alr_envs.make_detpmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
+    env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
     # OR for a DMP (other mp_kwargs are required, see dmc_examples):
     # env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs)
 
@@ -122,7 +122,7 @@ if __name__ == '__main__':
     example_dmc("button-press-v2", seed=10, iterations=500, render=render)
 
     # MP + MetaWorld hybrid task provided in the our framework
-    example_dmc("ButtonPressDetPMP-v2", seed=10, iterations=1, render=render)
+    example_dmc("ButtonPressProMP-v2", seed=10, iterations=1, render=render)
 
     # Custom MetaWorld task
     example_custom_dmc_and_mp(seed=10, iterations=1, render=render)
@ -126,7 +126,7 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
|
|||||||
}
|
}
|
||||||
env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
|
env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
|
||||||
# OR for a deterministic ProMP:
|
# OR for a deterministic ProMP:
|
||||||
# env = make_detpmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
|
# env = make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
|
||||||
|
|
||||||
if render:
|
if render:
|
||||||
env.render(mode="human")
|
env.render(mode="human")
|
||||||
@ -147,7 +147,7 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
render = True
|
render = False
|
||||||
# DMP
|
# DMP
|
||||||
example_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=render)
|
example_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=render)
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ def example_mp(env_name, seed=1):
|
|||||||
Example for running a motion primitive based version of a OpenAI-gym environment, which is already registered.
|
Example for running a motion primitive based version of a OpenAI-gym environment, which is already registered.
|
||||||
For more information on motion primitive specific stuff, look at the mp examples.
|
For more information on motion primitive specific stuff, look at the mp examples.
|
||||||
Args:
|
Args:
|
||||||
env_name: DetPMP env_id
|
env_name: ProMP env_id
|
||||||
seed: seed
|
seed: seed
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -35,7 +35,7 @@ if __name__ == '__main__':
|
|||||||
# example_mp("ReacherDMP-v2")
|
# example_mp("ReacherDMP-v2")
|
||||||
|
|
||||||
# DetProMP
|
# DetProMP
|
||||||
example_mp("ContinuousMountainCarDetPMP-v0")
|
example_mp("ContinuousMountainCarProMP-v0")
|
||||||
example_mp("ReacherDetPMP-v2")
|
example_mp("ReacherProMP-v2")
|
||||||
example_mp("FetchReachDenseDetPMP-v1")
|
example_mp("FetchReachDenseProMP-v1")
|
||||||
example_mp("FetchSlideDenseDetPMP-v1")
|
example_mp("FetchSlideDenseProMP-v1")
|
||||||
|
@ -2,7 +2,7 @@ import numpy as np
|
|||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
|
|
||||||
from alr_envs import dmc, meta
|
from alr_envs import dmc, meta
|
||||||
from alr_envs.utils.make_env_helpers import make_detpmp_env
|
from alr_envs.utils.make_env_helpers import make_promp_env
|
||||||
|
|
||||||
# This might work for some environments, however, please verify either way the correct trajectory information
|
# This might work for some environments, however, please verify either way the correct trajectory information
|
||||||
# for your environment are extracted below
|
# for your environment are extracted below
|
||||||
@ -26,8 +26,8 @@ mp_kwargs = {
|
|||||||
|
|
||||||
kwargs = dict(time_limit=2, episode_length=100)
|
kwargs = dict(time_limit=2, episode_length=100)
|
||||||
|
|
||||||
env = make_detpmp_env(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs,
|
env = make_promp_env(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
# Plot difference between real trajectory and target MP trajectory
|
# Plot difference between real trajectory and target MP trajectory
|
||||||
env.reset()
|
env.reset()
|
||||||
|
@ -3,7 +3,7 @@ from gym import register
|
|||||||
from . import goal_object_change_mp_wrapper, goal_change_mp_wrapper, goal_endeffector_change_mp_wrapper, \
|
from . import goal_object_change_mp_wrapper, goal_change_mp_wrapper, goal_endeffector_change_mp_wrapper, \
|
||||||
object_change_mp_wrapper
|
object_change_mp_wrapper
|
||||||
|
|
||||||
ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []}
|
ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
|
||||||
|
|
||||||
# MetaWorld
|
# MetaWorld
|
||||||
|
|
||||||
@ -12,10 +12,10 @@ _goal_change_envs = ["assembly-v2", "pick-out-of-hole-v2", "plate-slide-v2", "pl
|
|||||||
for _task in _goal_change_envs:
|
for _task in _goal_change_envs:
|
||||||
task_id_split = _task.split("-")
|
task_id_split = _task.split("-")
|
||||||
name = "".join([s.capitalize() for s in task_id_split[:-1]])
|
name = "".join([s.capitalize() for s in task_id_split[:-1]])
|
||||||
_env_id = f'{name}DetPMP-{task_id_split[-1]}'
|
_env_id = f'{name}ProMP-{task_id_split[-1]}'
|
||||||
register(
|
register(
|
||||||
id=_env_id,
|
id=_env_id,
|
||||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
|
||||||
kwargs={
|
kwargs={
|
||||||
"name": _task,
|
"name": _task,
|
||||||
"wrappers": [goal_change_mp_wrapper.MPWrapper],
|
"wrappers": [goal_change_mp_wrapper.MPWrapper],
|
||||||
@ -24,22 +24,21 @@ for _task in _goal_change_envs:
|
|||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 6.25,
|
"duration": 6.25,
|
||||||
"post_traj_time": 0,
|
"post_traj_time": 0,
|
||||||
"width": 0.025,
|
|
||||||
"zero_start": True,
|
"zero_start": True,
|
||||||
"policy_type": "metaworld",
|
"policy_type": "metaworld",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id)
|
ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||||
|
|
||||||
_object_change_envs = ["bin-picking-v2", "hammer-v2", "sweep-into-v2"]
|
_object_change_envs = ["bin-picking-v2", "hammer-v2", "sweep-into-v2"]
|
||||||
for _task in _object_change_envs:
|
for _task in _object_change_envs:
|
||||||
task_id_split = _task.split("-")
|
task_id_split = _task.split("-")
|
||||||
name = "".join([s.capitalize() for s in task_id_split[:-1]])
|
name = "".join([s.capitalize() for s in task_id_split[:-1]])
|
||||||
_env_id = f'{name}DetPMP-{task_id_split[-1]}'
|
_env_id = f'{name}ProMP-{task_id_split[-1]}'
|
||||||
register(
|
register(
|
||||||
id=_env_id,
|
id=_env_id,
|
||||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
|
||||||
kwargs={
|
kwargs={
|
||||||
"name": _task,
|
"name": _task,
|
||||||
"wrappers": [object_change_mp_wrapper.MPWrapper],
|
"wrappers": [object_change_mp_wrapper.MPWrapper],
|
||||||
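The registration loops above derive each new ProMP id mechanically from the MetaWorld task name. A standalone sketch of that mapping, using a task name that appears in this diff:

```python
# Reproduces the id construction from the loop body above.
_task = "button-press-v2"
task_id_split = _task.split("-")                              # ['button', 'press', 'v2']
name = "".join([s.capitalize() for s in task_id_split[:-1]])  # 'ButtonPress'
_env_id = f'{name}ProMP-{task_id_split[-1]}'
print(_env_id)  # ButtonPressProMP-v2
```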
@@ -48,13 +47,12 @@ for _task in _object_change_envs:
                 "num_basis": 5,
                 "duration": 6.25,
                 "post_traj_time": 0,
-                "width": 0.025,
                 "zero_start": True,
                 "policy_type": "metaworld",
             }
         }
     )
-    ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id)
+    ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)

 _goal_and_object_change_envs = ["box-close-v2", "button-press-v2", "button-press-wall-v2", "button-press-topdown-v2",
                                 "button-press-topdown-wall-v2", "coffee-button-v2", "coffee-pull-v2",
@@ -70,10 +68,10 @@ _goal_and_object_change_envs = ["box-close-v2", "button-press-v2", "button-press
 for _task in _goal_and_object_change_envs:
     task_id_split = _task.split("-")
     name = "".join([s.capitalize() for s in task_id_split[:-1]])
-    _env_id = f'{name}DetPMP-{task_id_split[-1]}'
+    _env_id = f'{name}ProMP-{task_id_split[-1]}'
     register(
         id=_env_id,
-        entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+        entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
         kwargs={
             "name": _task,
             "wrappers": [goal_object_change_mp_wrapper.MPWrapper],
@@ -82,22 +80,21 @@ for _task in _goal_and_object_change_envs:
                 "num_basis": 5,
                 "duration": 6.25,
                 "post_traj_time": 0,
-                "width": 0.025,
                 "zero_start": True,
                 "policy_type": "metaworld",
             }
         }
     )
-    ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id)
+    ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)

 _goal_and_endeffector_change_envs = ["basketball-v2"]
 for _task in _goal_and_endeffector_change_envs:
     task_id_split = _task.split("-")
     name = "".join([s.capitalize() for s in task_id_split[:-1]])
-    _env_id = f'{name}DetPMP-{task_id_split[-1]}'
+    _env_id = f'{name}ProMP-{task_id_split[-1]}'
     register(
         id=_env_id,
-        entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+        entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
         kwargs={
             "name": _task,
             "wrappers": [goal_endeffector_change_mp_wrapper.MPWrapper],
@@ -106,10 +103,9 @@ for _task in _goal_and_endeffector_change_envs:
                 "num_basis": 5,
                 "duration": 6.25,
                 "post_traj_time": 0,
-                "width": 0.025,
                 "zero_start": True,
                 "policy_type": "metaworld",
             }
         }
     )
-    ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id)
+    ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
@@ -8,7 +8,7 @@ These environments are wrapped-versions of their OpenAI-gym counterparts.

 |Name| Description|Trajectory Horizon|Action Dimension|Context Dimension
 |---|---|---|---|---|
-|`ContinuousMountainCarDetPMP-v0`| A DetPmP wrapped version of the ContinuousMountainCar-v0 environment. | 100 | 1
-|`ReacherDetPMP-v2`| A DetPmP wrapped version of the Reacher-v2 environment. | 50 | 2
-|`FetchSlideDenseDetPMP-v1`| A DetPmP wrapped version of the FetchSlideDense-v1 environment. | 50 | 4
-|`FetchReachDenseDetPMP-v1`| A DetPmP wrapped version of the FetchReachDense-v1 environment. | 50 | 4
+|`ContinuousMountainCarProMP-v0`| A ProMP wrapped version of the ContinuousMountainCar-v0 environment. | 100 | 1
+|`ReacherProMP-v2`| A ProMP wrapped version of the Reacher-v2 environment. | 50 | 2
+|`FetchSlideDenseProMP-v1`| A ProMP wrapped version of the FetchSlideDense-v1 environment. | 50 | 4
+|`FetchReachDenseProMP-v1`| A ProMP wrapped version of the FetchReachDense-v1 environment. | 50 | 4
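For the renamed OpenAI-gym wrappers in the table above, one action is a full set of motion primitive parameters, so a single `env.step()` rolls out the whole trajectory. A minimal rollout sketch, assuming gym's robotics dependencies are installed:

```python
import alr_envs

env = alr_envs.make('FetchReachDenseProMP-v1', seed=1)
obs = env.reset()
# The sampled action parameterizes the full 50-step trajectory from the table.
obs, reward, done, info = env.step(env.action_space.sample())
print(reward, done)  # done is typically True after the single MP step
```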
@@ -3,7 +3,7 @@ from gym.wrappers import FlattenObservation

 from . import classic_control, mujoco, robotics

-ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []}
+ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}

 # Short Continuous Mountain Car
 register(
@@ -16,8 +16,8 @@ register(
 # Open AI
 # Classic Control
 register(
-    id='ContinuousMountainCarDetPMP-v1',
-    entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+    id='ContinuousMountainCarProMP-v1',
+    entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
     kwargs={
         "name": "alr_envs:MountainCarContinuous-v1",
         "wrappers": [classic_control.continuous_mountain_car.MPWrapper],
@@ -26,7 +26,6 @@ register(
             "num_basis": 4,
             "duration": 2,
             "post_traj_time": 0,
-            "width": 0.02,
             "zero_start": True,
             "policy_type": "motor",
             "policy_kwargs": {
@@ -36,11 +35,11 @@ register(
         }
     }
 )
-ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("ContinuousMountainCarDetPMP-v1")
+ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("ContinuousMountainCarProMP-v1")

 register(
-    id='ContinuousMountainCarDetPMP-v0',
-    entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+    id='ContinuousMountainCarProMP-v0',
+    entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
     kwargs={
         "name": "gym.envs.classic_control:MountainCarContinuous-v0",
         "wrappers": [classic_control.continuous_mountain_car.MPWrapper],
@@ -49,7 +48,6 @@ register(
             "num_basis": 4,
             "duration": 19.98,
             "post_traj_time": 0,
-            "width": 0.02,
             "zero_start": True,
             "policy_type": "motor",
             "policy_kwargs": {
@@ -59,11 +57,11 @@ register(
         }
     }
 )
-ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("ContinuousMountainCarDetPMP-v0")
+ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("ContinuousMountainCarProMP-v0")

 register(
-    id='ReacherDetPMP-v2',
-    entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+    id='ReacherProMP-v2',
+    entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
     kwargs={
         "name": "gym.envs.mujoco:Reacher-v2",
         "wrappers": [mujoco.reacher_v2.MPWrapper],
@@ -72,7 +70,6 @@ register(
             "num_basis": 6,
             "duration": 1,
             "post_traj_time": 0,
-            "width": 0.02,
             "zero_start": True,
             "policy_type": "motor",
             "policy_kwargs": {
@@ -82,11 +79,11 @@ register(
         }
     }
 )
-ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("ReacherDetPMP-v2")
+ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("ReacherProMP-v2")

 register(
-    id='FetchSlideDenseDetPMP-v1',
-    entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+    id='FetchSlideDenseProMP-v1',
+    entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
     kwargs={
         "name": "gym.envs.robotics:FetchSlideDense-v1",
         "wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
@@ -95,17 +92,16 @@ register(
             "num_basis": 5,
             "duration": 2,
             "post_traj_time": 0,
-            "width": 0.02,
             "zero_start": True,
             "policy_type": "position"
         }
     }
 )
-ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("FetchSlideDenseDetPMP-v1")
+ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("FetchSlideDenseProMP-v1")

 register(
-    id='FetchSlideDetPMP-v1',
-    entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+    id='FetchSlideProMP-v1',
+    entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
     kwargs={
         "name": "gym.envs.robotics:FetchSlide-v1",
         "wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
@@ -114,17 +110,16 @@ register(
             "num_basis": 5,
             "duration": 2,
             "post_traj_time": 0,
-            "width": 0.02,
             "zero_start": True,
             "policy_type": "position"
         }
     }
 )
-ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("FetchSlideDetPMP-v1")
+ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("FetchSlideProMP-v1")

 register(
-    id='FetchReachDenseDetPMP-v1',
-    entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+    id='FetchReachDenseProMP-v1',
+    entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
     kwargs={
         "name": "gym.envs.robotics:FetchReachDense-v1",
         "wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
@@ -133,17 +128,16 @@ register(
             "num_basis": 5,
             "duration": 2,
             "post_traj_time": 0,
-            "width": 0.02,
             "zero_start": True,
             "policy_type": "position"
         }
     }
 )
-ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("FetchReachDenseDetPMP-v1")
+ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("FetchReachDenseProMP-v1")

 register(
-    id='FetchReachDetPMP-v1',
-    entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+    id='FetchReachProMP-v1',
+    entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
     kwargs={
         "name": "gym.envs.robotics:FetchReach-v1",
         "wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
@@ -152,10 +146,9 @@ register(
             "num_basis": 5,
             "duration": 2,
             "post_traj_time": 0,
-            "width": 0.02,
             "zero_start": True,
             "policy_type": "position"
         }
     }
 )
-ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("FetchReachDetPMP-v1")
+ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("FetchReachProMP-v1")
@@ -1,3 +1,4 @@
+import warnings
 from typing import Iterable, Type, Union

 import gym
@@ -5,7 +6,6 @@ import numpy as np
 from gym.envs.registration import EnvSpec

 from mp_env_api import MPEnvWrapper
-from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper
 from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper
 from mp_env_api.mp_wrappers.promp_wrapper import ProMPWrapper

@@ -48,6 +48,11 @@ def make(env_id: str, seed, **kwargs):
     Returns: Gym environment

     """
+    if any([det_pmp in env_id for det_pmp in ["DetPMP", "detpmp"]]):
+        warnings.warn("DetPMP is deprecated and converted to ProMP")
+        env_id = env_id.replace("DetPMP", "ProMP")
+        env_id = env_id.replace("detpmp", "promp")
+
     try:
         # Add seed to kwargs in case it is a predefined gym+dmc hybrid environment.
         if env_id.startswith("dmc"):
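The guard added above keeps old ids working during the transition. A minimal sketch of the resulting behaviour; since `warnings.warn` is called without a category, a plain `UserWarning` is raised:

```python
import warnings

import alr_envs

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # The legacy id is rewritten to ReacherProMP-v2 before lookup.
    env = alr_envs.make("ReacherDetPMP-v2", seed=1)
    print(caught[0].message)  # DetPMP is deprecated and converted to ProMP
```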
@@ -153,26 +158,6 @@ def make_promp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwar
     return ProMPWrapper(_env, **mp_kwargs)


-def make_detpmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs):
-    """
-    This can also be used standalone for manually building a custom Det ProMP environment.
-    Args:
-        env_id: base_env_name,
-        wrappers: list of wrappers (at least an MPEnvWrapper),
-        mp_kwargs: dict of at least {num_dof: int, num_basis: int, width: int}
-
-    Returns: Det ProMP wrapped gym env
-
-    """
-    _verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
-
-    _env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
-
-    _verify_dof(_env, mp_kwargs.get("num_dof"))
-
-    return DetPMPWrapper(_env, **mp_kwargs)
-
-
 def make_dmp_env_helper(**kwargs):
     """
     Helper function for registering a DMP gym environments.
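With `make_detpmp_env` removed, `make_promp_env` is the single entry point for manually built ProMP environments. A sketch of the surviving call, reusing the Reacher wrapper registered in this diff; the `num_dof` value and the omission of `policy_kwargs` are assumptions, not taken from the removed code:

```python
import alr_envs
from alr_envs.open_ai import mujoco

env = alr_envs.make_promp_env(
    "gym.envs.mujoco:Reacher-v2",
    wrappers=[mujoco.reacher_v2.MPWrapper],
    seed=1,
    mp_kwargs={
        "num_dof": 2,          # assumption: Reacher-v2 has two actuated joints
        "num_basis": 6,        # values below mirror the ReacherProMP-v2 registration
        "duration": 1,
        "post_traj_time": 0,
        "zero_start": True,
        "policy_type": "motor",
    },
)
```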
@@ -212,26 +197,6 @@ def make_promp_env_helper(**kwargs):
                           mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs)


-def make_detpmp_env_helper(**kwargs):
-    """
-    Helper function for registering ProMP gym environments.
-    This can also be used standalone for manually building a custom ProMP environment.
-    Args:
-        **kwargs: expects at least the following:
-        {
-        "name": base_env_name,
-        "wrappers": list of wrappers (at least an MPEnvWrapper),
-        "mp_kwargs": dict of at least {num_dof: int, num_basis: int, width: int}
-        }
-
-    Returns: DMP wrapped gym env
-
-    """
-    seed = kwargs.pop("seed", None)
-    return make_detpmp_env(env_id=kwargs.pop("name"), wrappers=kwargs.pop("wrappers"), seed=seed,
-                           mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs)
-
-
 def _verify_time_limit(mp_time_limit: Union[None, float], env_time_limit: Union[None, float]):
     """
     When using DMC check if a manually specified time limit matches the trajectory duration the MP receives.
@@ -98,8 +98,8 @@ class TestMPEnvironments(unittest.TestCase):
                 with self.subTest(msg=env_id):
                     self._run_env(env_id)

-        with self.subTest(msg="DetPMP"):
-            for env_id in alr_envs.ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS['DetPMP']:
+        with self.subTest(msg="ProMP"):
+            for env_id in alr_envs.ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS['ProMP']:
                 with self.subTest(msg=env_id):
                     self._run_env(env_id)

@@ -110,8 +110,8 @@ class TestMPEnvironments(unittest.TestCase):
                 with self.subTest(msg=env_id):
                     self._run_env(env_id)

-        with self.subTest(msg="DetPMP"):
-            for env_id in alr_envs.ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS['DetPMP']:
+        with self.subTest(msg="ProMP"):
+            for env_id in alr_envs.ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS['ProMP']:
                 with self.subTest(msg=env_id):
                     self._run_env(env_id)

@@ -122,8 +122,8 @@ class TestMPEnvironments(unittest.TestCase):
                 with self.subTest(msg=env_id):
                     self._run_env(env_id)

-        with self.subTest(msg="DetPMP"):
-            for env_id in alr_envs.ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS['DetPMP']:
+        with self.subTest(msg="ProMP"):
+            for env_id in alr_envs.ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS['ProMP']:
                 with self.subTest(msg=env_id):
                     self._run_env(env_id)

@@ -134,8 +134,8 @@ class TestMPEnvironments(unittest.TestCase):
                 with self.subTest(msg=env_id):
                     self._run_env(env_id)

-        with self.subTest(msg="DetPMP"):
-            for env_id in alr_envs.ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS['DetPMP']:
+        with self.subTest(msg="ProMP"):
+            for env_id in alr_envs.ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS['ProMP']:
                 with self.subTest(msg=env_id):
                     self._run_env(env_id)

@@ -143,29 +143,29 @@ class TestMPEnvironments(unittest.TestCase):
         """Tests that identical seeds produce identical trajectories for ALR MP Envs."""
         with self.subTest(msg="DMP"):
             self._run_env_determinism(alr_envs.ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"])
-        with self.subTest(msg="DetPMP"):
-            self._run_env_determinism(alr_envs.ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"])
+        with self.subTest(msg="ProMP"):
+            self._run_env_determinism(alr_envs.ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"])

     def test_openai_environment_determinism(self):
         """Tests that identical seeds produce identical trajectories for OpenAI gym MP Envs."""
         with self.subTest(msg="DMP"):
             self._run_env_determinism(alr_envs.ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"])
-        with self.subTest(msg="DetPMP"):
-            self._run_env_determinism(alr_envs.ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"])
+        with self.subTest(msg="ProMP"):
+            self._run_env_determinism(alr_envs.ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"])

     def test_dmc_environment_determinism(self):
         """Tests that identical seeds produce identical trajectories for DMC MP Envs."""
         with self.subTest(msg="DMP"):
             self._run_env_determinism(alr_envs.ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"])
-        with self.subTest(msg="DetPMP"):
-            self._run_env_determinism(alr_envs.ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"])
+        with self.subTest(msg="ProMP"):
+            self._run_env_determinism(alr_envs.ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"])

     def test_metaworld_environment_determinism(self):
         """Tests that identical seeds produce identical trajectories for Metaworld MP Envs."""
         with self.subTest(msg="DMP"):
             self._run_env_determinism(alr_envs.ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"])
-        with self.subTest(msg="DetPMP"):
-            self._run_env_determinism(alr_envs.ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"])
+        with self.subTest(msg="ProMP"):
+            self._run_env_determinism(alr_envs.ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"])


 if __name__ == '__main__':
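`_run_env_determinism` is only re-pointed at the ProMP lists here; its body is not part of this diff. The sketch below illustrates the kind of check it presumably performs, with the rollout helper and env id being our own choices:

```python
import numpy as np

import alr_envs


def rollout(env_id, seed=0):
    """Collect one episode of (observation, reward) pairs under a fixed seed."""
    env = alr_envs.make(env_id, seed=seed)
    env.action_space.seed(seed)
    traj, done = [], False
    env.reset()
    while not done:
        obs, reward, done, info = env.step(env.action_space.sample())
        traj.append((obs, reward))
    return traj


# Two identically seeded rollouts should match element-wise.
t1 = rollout("ReacherProMP-v2")
t2 = rollout("ReacherProMP-v2")
assert all(np.allclose(o1, o2) and np.allclose(r1, r2)
           for (o1, r1), (o2, r2) in zip(t1, t2))
```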
@@ -81,13 +81,13 @@ class TestStepMetaWorlEnvironments(unittest.TestCase):
     def _verify_done(self, done):
         self.assertIsInstance(done, bool, f"Returned {done} as done flag, expected bool.")

-    def test_dmc_functionality(self):
+    def test_metaworld_functionality(self):
         """Tests that environments runs without errors using random actions."""
         for env_id in ALL_ENVS:
             with self.subTest(msg=env_id):
                 self._run_env(env_id)

-    def test_dmc_determinism(self):
+    def test_metaworld_determinism(self):
         """Tests that identical seeds produce identical trajectories."""
         seed = 0
         # Iterate over two trajectories, which should have the same state and action sequence