fixed OpenAI fetch tasks; added nicer imports

commit a11965827d (parent: f5fcbf7f54)
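In short: the per-environment wrapper classes (HoleReacherMPWrapper, DMCBallInCupMPWrapper, ...) become module-local MPWrapper classes addressed through the subpackages classic_control, dmc, and open_ai; the make_* helpers are re-exported at package level; and the OpenAI Fetch registrations gain gym's FlattenObservation. A minimal sketch of the resulting user-facing API (illustrative usage, mirroring the examples updated below):

    import alr_envs  # importing the package registers all environments

    env = alr_envs.make_env("alr_envs:HoleReacher-v0", seed=1)
    obs = env.reset()
    obs, reward, done, info = env.step(env.action_space.sample())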
@@ -1,15 +1,12 @@
-import numpy as np
 from gym.envs.registration import register
+from gym.wrappers import FlattenObservation
 
-from alr_envs.classic_control.hole_reacher.hole_reacher_mp_wrapper import HoleReacherMPWrapper
-from alr_envs.classic_control.simple_reacher.simple_reacher_mp_wrapper import SimpleReacherMPWrapper
-from alr_envs.classic_control.viapoint_reacher.viapoint_reacher_mp_wrapper import ViaPointReacherMPWrapper
-from alr_envs.dmc.manipulation.reach.reach_mp_wrapper import DMCReachSiteMPWrapper
-from alr_envs.dmc.suite.ball_in_cup.ball_in_cup_mp_wrapper import DMCBallInCupMPWrapper
-from alr_envs.dmc.suite.cartpole.cartpole_mp_wrapper import DMCCartpoleMPWrapper, DMCCartpoleThreePolesMPWrapper, \
-    DMCCartpoleTwoPolesMPWrapper
-from alr_envs.dmc.suite.reacher.reacher_mp_wrapper import DMCReacherMPWrapper
-from alr_envs.open_ai import reacher_v2, continuous_mountain_car, fetch
+from alr_envs import classic_control, dmc, open_ai
+from alr_envs.utils.make_env_helpers import make_dmp_env
+from alr_envs.utils.make_env_helpers import make_detpmp_env
+from alr_envs.utils.make_env_helpers import make_env
+from alr_envs.utils.make_env_helpers import make_env_rank
+
 
 # Mujoco
 
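For downstream code the import surface changes as sketched here; both forms name the same class, and the first is removed by this commit:

    # before: one renamed wrapper class per environment
    from alr_envs.classic_control.hole_reacher.hole_reacher_mp_wrapper import HoleReacherMPWrapper

    # after: a common MPWrapper name, addressed via its subpackage
    from alr_envs import classic_control
    wrapper = classic_control.hole_reacher.MPWrapper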
@@ -206,7 +203,7 @@ for v in versions:
         # max_episode_steps=1,
         kwargs={
             "name": f"alr_envs:{v}",
-            "wrappers": [SimpleReacherMPWrapper],
+            "wrappers": [classic_control.simple_reacher.MPWrapper],
             "mp_kwargs": {
                 "num_dof": 2 if "long" not in v.lower() else 5,
                 "num_basis": 5,

@@ -225,7 +222,7 @@ register(
     # max_episode_steps=1,
     kwargs={
         "name": "alr_envs:ViaPointReacher-v0",
-        "wrappers": [ViaPointReacherMPWrapper],
+        "wrappers": [classic_control.viapoint_reacher.MPWrapper],
         "mp_kwargs": {
             "num_dof": 5,
             "num_basis": 5,
@@ -238,6 +235,25 @@ register(
     }
 )
 
+register(
+    id='ViaPointReacherDetPMP-v0',
+    entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
+    # max_episode_steps=1,
+    kwargs={
+        "name": "alr_envs:ViaPointReacher-v0",
+        "wrappers": [classic_control.viapoint_reacher.MPWrapper],
+        "mp_kwargs": {
+            "num_dof": 5,
+            "num_basis": 5,
+            "duration": 2,
+            "width": 0.025,
+            "policy_type": "velocity",
+            "weights_scale": 0.2,
+            "zero_start": True
+        }
+    }
+)
+
 ## Hole Reacher
 versions = ["v0", "v1", "v2"]
 for v in versions:
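For reference, building the newly registered DetPMP environment without the registry would look roughly like the sketch below, using the make_detpmp_env helper whose signature appears later in this diff; the keyword values simply mirror the registration above:

    from alr_envs import classic_control
    from alr_envs.utils.make_env_helpers import make_detpmp_env

    env = make_detpmp_env(
        "alr_envs:ViaPointReacher-v0",
        wrappers=[classic_control.viapoint_reacher.MPWrapper],
        seed=1,
        mp_kwargs={"num_dof": 5, "num_basis": 5, "duration": 2, "width": 0.025,
                   "policy_type": "velocity", "weights_scale": 0.2, "zero_start": True},
    )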
@@ -247,7 +263,7 @@ for v in versions:
         # max_episode_steps=1,
         kwargs={
             "name": f"alr_envs:HoleReacher-{v}",
-            "wrappers": [HoleReacherMPWrapper],
+            "wrappers": [classic_control.hole_reacher.MPWrapper],
             "mp_kwargs": {
                 "num_dof": 5,
                 "num_basis": 5,

@@ -267,7 +283,7 @@ for v in versions:
         entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
         kwargs={
             "name": f"alr_envs:HoleReacher-{v}",
-            "wrappers": [HoleReacherMPWrapper],
+            "wrappers": [classic_control.hole_reacher.MPWrapper],
             "mp_kwargs": {
                 "num_dof": 5,
                 "num_basis": 5,

@@ -283,11 +299,6 @@ for v in versions:
 ## Deep Mind Control Suite (DMC)
 ### Suite
 
-# tasks = ["ball_in_cup-catch", "reacher-easy", "reacher-hard", "cartpole-balance", "cartpole-balance_sparse",
-#          "cartpole-swingup", "cartpole-swingup_sparse", "cartpole-two_poles", "cartpole-three_poles"]
-# wrappers = [DMCBallInCupMPWrapper, DMCReacherMPWrapper, DMCReacherMPWrapper, DMCCartpoleMPWrapper,
-#             partial(DMCCartpoleMPWrapper)]
-# for t, w in zip(tasks, wrappers):
 register(
     id=f'dmc_ball_in_cup-catch_dmp-v0',
     entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',

@@ -296,7 +307,7 @@ register(
         "name": f"ball_in_cup-catch",
         "time_limit": 1,
         "episode_length": 50,
-        "wrappers": [DMCBallInCupMPWrapper],
+        "wrappers": [dmc.suite.ball_in_cup.MPWrapper],
         "mp_kwargs": {
             "num_dof": 2,
             "num_basis": 5,

@@ -322,7 +333,7 @@ register(
         "name": f"ball_in_cup-catch",
         "time_limit": 1,
         "episode_length": 50,
-        "wrappers": [DMCBallInCupMPWrapper],
+        "wrappers": [dmc.suite.ball_in_cup.MPWrapper],
         "mp_kwargs": {
             "num_dof": 2,
             "num_basis": 5,

@@ -339,7 +350,7 @@ register(
     }
 )
 
-# TODO tune gains and episode length for all below
+# TODO tune episode length for all below
 register(
     id=f'dmc_reacher-easy_dmp-v0',
     entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',

@@ -348,7 +359,7 @@ register(
         "name": f"reacher-easy",
         "time_limit": 1,
         "episode_length": 50,
-        "wrappers": [DMCReacherMPWrapper],
+        "wrappers": [dmc.suite.reacher.MPWrapper],
         "mp_kwargs": {
             "num_dof": 2,
             "num_basis": 5,

@@ -374,7 +385,7 @@ register(
         "name": f"reacher-easy",
         "time_limit": 1,
         "episode_length": 50,
-        "wrappers": [DMCReacherMPWrapper],
+        "wrappers": [dmc.suite.reacher.MPWrapper],
         "mp_kwargs": {
             "num_dof": 2,
             "num_basis": 5,

@@ -399,7 +410,7 @@ register(
         "name": f"reacher-hard",
         "time_limit": 1,
         "episode_length": 50,
-        "wrappers": [DMCReacherMPWrapper],
+        "wrappers": [dmc.suite.reacher.MPWrapper],
         "mp_kwargs": {
             "num_dof": 2,
             "num_basis": 5,

@@ -425,7 +436,7 @@ register(
         "name": f"reacher-hard",
         "time_limit": 1,
         "episode_length": 50,
-        "wrappers": [DMCReacherMPWrapper],
+        "wrappers": [dmc.suite.reacher.MPWrapper],
         "mp_kwargs": {
             "num_dof": 2,
             "num_basis": 5,

@@ -448,8 +459,9 @@ register(
     kwargs={
         "name": f"cartpole-balance",
         # "time_limit": 1,
+        "camera_id": 0,
         "episode_length": 1000,
-        "wrappers": [DMCCartpoleMPWrapper],
+        "wrappers": [dmc.suite.cartpole.MPWrapper],
         "mp_kwargs": {
             "num_dof": 1,
             "num_basis": 5,

@@ -461,8 +473,8 @@ register(
             "weights_scale": 50,
             "goal_scale": 0.1,
             "policy_kwargs": {
-                "p_gains": 50,
-                "d_gains": 1
+                "p_gains": 10,
+                "d_gains": 10
             }
         }
     }
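The same gain change (p_gains 50 -> 10, d_gains 1 -> 10) repeats for every cartpole registration below. Assuming the "motor" policy type applies a standard PD tracking rule (the actual controller lives in mp_env_api, so this is an illustrative sketch rather than the verified implementation):

    import numpy as np

    def pd_torque(des_pos, des_vel, cur_pos, cur_vel, p_gains=10.0, d_gains=10.0):
        # torque = Kp * position error + Kd * velocity error
        return p_gains * (des_pos - cur_pos) + d_gains * (des_vel - cur_vel)

The pd_control_gain_tuning.py example added at the end of this diff exists precisely to sanity-check such gains against the actual rollout.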
@@ -474,8 +486,9 @@ register(
     kwargs={
         "name": f"cartpole-balance",
         # "time_limit": 1,
+        "camera_id": 0,
         "episode_length": 1000,
-        "wrappers": [DMCCartpoleMPWrapper],
+        "wrappers": [dmc.suite.cartpole.MPWrapper],
         "mp_kwargs": {
             "num_dof": 1,
             "num_basis": 5,

@@ -485,8 +498,8 @@ register(
             "weights_scale": 0.2,
             "zero_start": True,
             "policy_kwargs": {
-                "p_gains": 50,
-                "d_gains": 1
+                "p_gains": 10,
+                "d_gains": 10
             }
         }
     }

@@ -498,8 +511,9 @@ register(
     kwargs={
         "name": f"cartpole-balance_sparse",
         # "time_limit": 1,
+        "camera_id": 0,
         "episode_length": 1000,
-        "wrappers": [DMCCartpoleMPWrapper],
+        "wrappers": [dmc.suite.cartpole.MPWrapper],
         "mp_kwargs": {
             "num_dof": 1,
             "num_basis": 5,

@@ -511,8 +525,8 @@ register(
             "weights_scale": 50,
             "goal_scale": 0.1,
             "policy_kwargs": {
-                "p_gains": 50,
-                "d_gains": 1
+                "p_gains": 10,
+                "d_gains": 10
             }
         }
     }

@@ -524,8 +538,9 @@ register(
     kwargs={
         "name": f"cartpole-balance_sparse",
         # "time_limit": 1,
+        "camera_id": 0,
         "episode_length": 1000,
-        "wrappers": [DMCCartpoleMPWrapper],
+        "wrappers": [dmc.suite.cartpole.MPWrapper],
         "mp_kwargs": {
             "num_dof": 1,
             "num_basis": 5,

@@ -535,8 +550,8 @@ register(
             "weights_scale": 0.2,
             "zero_start": True,
             "policy_kwargs": {
-                "p_gains": 50,
-                "d_gains": 1
+                "p_gains": 10,
+                "d_gains": 10
             }
         }
     }

@@ -549,8 +564,9 @@ register(
     kwargs={
         "name": f"cartpole-swingup",
         # "time_limit": 1,
+        "camera_id": 0,
         "episode_length": 1000,
-        "wrappers": [DMCCartpoleMPWrapper],
+        "wrappers": [dmc.suite.cartpole.MPWrapper],
         "mp_kwargs": {
             "num_dof": 1,
             "num_basis": 5,

@@ -562,8 +578,8 @@ register(
             "weights_scale": 50,
             "goal_scale": 0.1,
             "policy_kwargs": {
-                "p_gains": 50,
-                "d_gains": 1
+                "p_gains": 10,
+                "d_gains": 10
             }
         }
     }

@@ -575,8 +591,9 @@ register(
     kwargs={
         "name": f"cartpole-swingup",
         # "time_limit": 1,
+        "camera_id": 0,
         "episode_length": 1000,
-        "wrappers": [DMCCartpoleMPWrapper],
+        "wrappers": [dmc.suite.cartpole.MPWrapper],
         "mp_kwargs": {
             "num_dof": 1,
             "num_basis": 5,

@@ -586,8 +603,8 @@ register(
             "weights_scale": 0.2,
             "zero_start": True,
             "policy_kwargs": {
-                "p_gains": 50,
-                "d_gains": 1
+                "p_gains": 10,
+                "d_gains": 10
             }
         }
     }

@@ -599,8 +616,9 @@ register(
     kwargs={
         "name": f"cartpole-swingup_sparse",
         # "time_limit": 1,
+        "camera_id": 0,
         "episode_length": 1000,
-        "wrappers": [DMCCartpoleMPWrapper],
+        "wrappers": [dmc.suite.cartpole.MPWrapper],
         "mp_kwargs": {
             "num_dof": 1,
             "num_basis": 5,

@@ -612,8 +630,8 @@ register(
             "weights_scale": 50,
             "goal_scale": 0.1,
             "policy_kwargs": {
-                "p_gains": 50,
-                "d_gains": 1
+                "p_gains": 10,
+                "d_gains": 10
             }
         }
     }

@@ -625,8 +643,9 @@ register(
     kwargs={
         "name": f"cartpole-swingup_sparse",
         # "time_limit": 1,
+        "camera_id": 0,
         "episode_length": 1000,
-        "wrappers": [DMCCartpoleMPWrapper],
+        "wrappers": [dmc.suite.cartpole.MPWrapper],
         "mp_kwargs": {
             "num_dof": 1,
             "num_basis": 5,

@@ -636,8 +655,8 @@ register(
             "weights_scale": 0.2,
             "zero_start": True,
             "policy_kwargs": {
-                "p_gains": 50,
-                "d_gains": 1
+                "p_gains": 10,
+                "d_gains": 10
             }
         }
     }

@@ -649,9 +668,10 @@ register(
     kwargs={
         "name": f"cartpole-two_poles",
         # "time_limit": 1,
+        "camera_id": 0,
         "episode_length": 1000,
         # "wrappers": [partial(DMCCartpoleMPWrapper, n_poles=2)],
-        "wrappers": [DMCCartpoleTwoPolesMPWrapper],
+        "wrappers": [dmc.suite.cartpole.TwoPolesMPWrapper],
         "mp_kwargs": {
             "num_dof": 1,
             "num_basis": 5,

@@ -663,8 +683,8 @@ register(
             "weights_scale": 50,
             "goal_scale": 0.1,
             "policy_kwargs": {
-                "p_gains": 50,
-                "d_gains": 1
+                "p_gains": 10,
+                "d_gains": 10
             }
         }
     }

@@ -676,9 +696,10 @@ register(
     kwargs={
         "name": f"cartpole-two_poles",
         # "time_limit": 1,
+        "camera_id": 0,
         "episode_length": 1000,
         # "wrappers": [partial(DMCCartpoleMPWrapper, n_poles=2)],
-        "wrappers": [DMCCartpoleTwoPolesMPWrapper],
+        "wrappers": [dmc.suite.cartpole.TwoPolesMPWrapper],
         "mp_kwargs": {
             "num_dof": 1,
             "num_basis": 5,

@@ -688,8 +709,8 @@ register(
             "weights_scale": 0.2,
             "zero_start": True,
             "policy_kwargs": {
-                "p_gains": 50,
-                "d_gains": 1
+                "p_gains": 10,
+                "d_gains": 10
             }
         }
     }

@@ -701,9 +722,10 @@ register(
     kwargs={
         "name": f"cartpole-three_poles",
         # "time_limit": 1,
+        "camera_id": 0,
         "episode_length": 1000,
         # "wrappers": [partial(DMCCartpoleMPWrapper, n_poles=3)],
-        "wrappers": [DMCCartpoleThreePolesMPWrapper],
+        "wrappers": [dmc.suite.cartpole.ThreePolesMPWrapper],
         "mp_kwargs": {
             "num_dof": 1,
             "num_basis": 5,

@@ -715,8 +737,8 @@ register(
             "weights_scale": 50,
             "goal_scale": 0.1,
             "policy_kwargs": {
-                "p_gains": 50,
-                "d_gains": 1
+                "p_gains": 10,
+                "d_gains": 10
             }
         }
     }

@@ -728,9 +750,10 @@ register(
     kwargs={
         "name": f"cartpole-three_poles",
         # "time_limit": 1,
+        "camera_id": 0,
         "episode_length": 1000,
         # "wrappers": [partial(DMCCartpoleMPWrapper, n_poles=3)],
-        "wrappers": [DMCCartpoleThreePolesMPWrapper],
+        "wrappers": [dmc.suite.cartpole.ThreePolesMPWrapper],
         "mp_kwargs": {
             "num_dof": 1,
             "num_basis": 5,

@@ -740,8 +763,8 @@ register(
             "weights_scale": 0.2,
             "zero_start": True,
             "policy_kwargs": {
-                "p_gains": 50,
-                "d_gains": 1
+                "p_gains": 10,
+                "d_gains": 10
             }
         }
     }

@@ -757,7 +780,7 @@ register(
         "name": f"manipulation-reach_site_features",
         # "time_limit": 1,
         "episode_length": 250,
-        "wrappers": [DMCReachSiteMPWrapper],
+        "wrappers": [dmc.manipulation.reach.MPWrapper],
         "mp_kwargs": {
             "num_dof": 9,
             "num_basis": 5,

@@ -779,7 +802,7 @@ register(
         "name": f"manipulation-reach_site_features",
         # "time_limit": 1,
         "episode_length": 250,
-        "wrappers": [DMCReachSiteMPWrapper],
+        "wrappers": [dmc.manipulation.reach.MPWrapper],
         "mp_kwargs": {
             "num_dof": 9,
             "num_basis": 5,

@@ -798,7 +821,7 @@ register(
     entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
     kwargs={
         "name": "gym.envs.classic_control:MountainCarContinuous-v0",
-        "wrappers": [continuous_mountain_car.MPWrapper],
+        "wrappers": [open_ai.classic_control.continuous_mountain_car.MPWrapper],
         "mp_kwargs": {
             "num_dof": 1,
             "num_basis": 4,

@@ -819,7 +842,7 @@ register(
     entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
     kwargs={
         "name": "gym.envs.mujoco:Reacher-v2",
-        "wrappers": [reacher_v2.MPWrapper],
+        "wrappers": [open_ai.mujoco.reacher_v2.MPWrapper],
         "mp_kwargs": {
             "num_dof": 2,
             "num_basis": 6,

@@ -840,7 +863,7 @@ register(
     entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
     kwargs={
         "name": "gym.envs.robotics:FetchSlideDense-v1",
-        "wrappers": [fetch.MPWrapper],
+        "wrappers": [FlattenObservation, open_ai.robotics.fetch.MPWrapper],
         "mp_kwargs": {
             "num_dof": 4,
             "num_basis": 5,

@@ -857,7 +880,7 @@ register(
     entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
     kwargs={
         "name": "gym.envs.robotics:FetchReachDense-v1",
-        "wrappers": [fetch.MPWrapper],
+        "wrappers": [FlattenObservation, open_ai.robotics.fetch.MPWrapper],
         "mp_kwargs": {
             "num_dof": 4,
             "num_basis": 5,
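This is the "fixed OpenAI fetch tasks" part: gym's robotics envs return a Dict observation (achieved_goal, desired_goal, observation), while the MP machinery expects a flat vector, so FlattenObservation is applied before the MP wrapper. A sketch of the difference, relying only on standard gym behavior:

    import gym
    from gym.wrappers import FlattenObservation

    env = gym.make("gym.envs.robotics:FetchReachDense-v1")
    # env.observation_space is a Dict space here
    flat = FlattenObservation(env)
    # flat.observation_space is a 1D Box; the dict entries are concatenated
    # in key order: achieved_goal, desired_goal, observation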
@@ -0,0 +1 @@
+from .mp_wrapper import MPWrapper

@@ -2,10 +2,10 @@ from typing import Tuple, Union
 
 import numpy as np
 
-from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper
+from mp_env_api import MPEnvWrapper
 
 
-class HoleReacherMPWrapper(MPEnvWrapper):
+class MPWrapper(MPEnvWrapper):
     @property
     def active_obs(self):
         return np.hstack([

@@ -0,0 +1 @@
+from .mp_wrapper import MPWrapper

@@ -2,10 +2,10 @@ from typing import Tuple, Union
 
 import numpy as np
 
-from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper
+from mp_env_api import MPEnvWrapper
 
 
-class SimpleReacherMPWrapper(MPEnvWrapper):
+class MPWrapper(MPEnvWrapper):
     @property
     def active_obs(self):
         return np.hstack([

@@ -0,0 +1 @@
+from .mp_wrapper import MPWrapper

@@ -2,10 +2,10 @@ from typing import Tuple, Union
 
 import numpy as np
 
-from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper
+from mp_env_api import MPEnvWrapper
 
 
-class ViaPointReacherMPWrapper(MPEnvWrapper):
+class MPWrapper(MPEnvWrapper):
     @property
     def active_obs(self):
         return np.hstack([

@@ -0,0 +1,5 @@
+# from alr_envs.dmc import manipulation, suite
+from alr_envs.dmc.suite import ball_in_cup
+from alr_envs.dmc.suite import reacher
+from alr_envs.dmc.suite import cartpole
+from alr_envs.dmc.manipulation import reach

@@ -0,0 +1 @@
+from .mp_wrapper import MPWrapper

@@ -2,10 +2,10 @@ from typing import Tuple, Union
 
 import numpy as np
 
-from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper
+from mp_env_api import MPEnvWrapper
 
 
-class DMCReachSiteMPWrapper(MPEnvWrapper):
+class MPWrapper(MPEnvWrapper):
 
     @property
     def active_obs(self):

@@ -0,0 +1 @@
+from .mp_wrapper import MPWrapper

@@ -2,10 +2,10 @@ from typing import Tuple, Union
 
 import numpy as np
 
-from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper
+from mp_env_api import MPEnvWrapper
 
 
-class DMCBallInCupMPWrapper(MPEnvWrapper):
+class MPWrapper(MPEnvWrapper):
 
     @property
     def active_obs(self):

@@ -0,0 +1,3 @@
+from .mp_wrapper import MPWrapper
+from .mp_wrapper import TwoPolesMPWrapper
+from .mp_wrapper import ThreePolesMPWrapper

@@ -2,10 +2,10 @@ from typing import Tuple, Union
 
 import numpy as np
 
-from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper
+from mp_env_api import MPEnvWrapper
 
 
-class DMCCartpoleMPWrapper(MPEnvWrapper):
+class MPWrapper(MPEnvWrapper):
 
     def __init__(self, env, n_poles: int = 1):
         self.n_poles = n_poles

@@ -39,13 +39,13 @@ class DMCCartpoleMPWrapper(MPEnvWrapper):
         return self.env.dt
 
 
-class DMCCartpoleTwoPolesMPWrapper(DMCCartpoleMPWrapper):
+class TwoPolesMPWrapper(MPWrapper):
 
     def __init__(self, env):
         super().__init__(env, n_poles=2)
 
 
-class DMCCartpoleThreePolesMPWrapper(DMCCartpoleMPWrapper):
+class ThreePolesMPWrapper(MPWrapper):
 
     def __init__(self, env):
         super().__init__(env, n_poles=3)

@@ -0,0 +1 @@
+from .mp_wrapper import MPWrapper

@@ -2,10 +2,10 @@ from typing import Tuple, Union
 
 import numpy as np
 
-from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper
+from mp_env_api import MPEnvWrapper
 
 
-class DMCReacherMPWrapper(MPEnvWrapper):
+class MPWrapper(MPEnvWrapper):
 
     @property
     def active_obs(self):

@@ -1,5 +1,5 @@
-from alr_envs.dmc.suite.ball_in_cup.ball_in_cup_mp_wrapper import DMCBallInCupMPWrapper
-from alr_envs.utils.make_env_helpers import make_dmp_env, make_env
+import alr_envs
+from alr_envs.dmc.suite.ball_in_cup.mp_wrapper import MPWrapper
 
 
 def example_dmc(env_id="fish-swim", seed=1, iterations=1000, render=True):

@@ -17,13 +17,12 @@ def example_dmc(env_id="fish-swim", seed=1, iterations=1000, render=True):
     Returns:
 
     """
-    env = make_env(env_id, seed)
+    env = alr_envs.make_env(env_id, seed)
     rewards = 0
     obs = env.reset()
     print("observation shape:", env.observation_space.shape)
     print("action shape:", env.action_space.shape)
 
-    # number of samples(multiple environment steps)
     for i in range(iterations):
         ac = env.action_space.sample()
         obs, reward, done, info = env.step(ac)

@@ -63,7 +62,7 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
 
     # Replace this wrapper with the custom wrapper for your environment by inheriting from the MPEnvWrapper.
     # You can also add other gym.Wrappers in case they are needed.
-    wrappers = [DMCBallInCupMPWrapper]
+    wrappers = [MPWrapper]
     mp_kwargs = {
         "num_dof": 2,
         "num_basis": 5,

@@ -84,9 +83,9 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
         "episode_length": 1000,
         # "frame_skip": 1
     }
-    env = make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs)
+    env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs)
     # OR for a deterministic ProMP:
-    # env = make_detpmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args)
+    # env = alr_envs.make_detpmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args)
 
     # This renders the full MP trajectory
     # It is only required to call render() once in the beginning, which renders every consecutive trajectory.

@@ -1,11 +1,9 @@
-import warnings
 from collections import defaultdict
 
 import gym
 import numpy as np
 
-from alr_envs.utils.make_env_helpers import make_env, make_env_rank
-from alr_envs.utils.mp_env_async_sampler import AlrContextualMpEnvSampler, AlrMpEnvSampler, DummyDist
+import alr_envs
 
 
 def example_general(env_id="Pendulum-v0", seed=1, iterations=1000, render=True):

@@ -23,7 +21,7 @@ def example_general(env_id="Pendulum-v0", seed=1, iterations=1000, render=True):
 
     """
 
-    env = make_env(env_id, seed)
+    env = alr_envs.make_env(env_id, seed)
     rewards = 0
     obs = env.reset()
     print("Observation shape: ", env.observation_space.shape)

@@ -58,7 +56,7 @@ def example_async(env_id="alr_envs:HoleReacher-v0", n_cpu=4, seed=int('533D', 16
     Returns: Tuple of (obs, reward, done, info) with type np.ndarray
 
     """
-    env = gym.vector.AsyncVectorEnv([make_env_rank(env_id, seed, i) for i in range(n_cpu)])
+    env = gym.vector.AsyncVectorEnv([alr_envs.make_env_rank(env_id, seed, i) for i in range(n_cpu)])
     # OR
     # envs = gym.vector.AsyncVectorEnv([make_env(env_id, seed + i) for i in range(n_cpu)])
 

@@ -1,4 +1,4 @@
-from alr_envs import HoleReacherMPWrapper
+from alr_envs import MPWrapper
 from alr_envs.utils.make_env_helpers import make_dmp_env, make_env
 
 

@@ -113,7 +113,7 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
 
     # Replace this wrapper with the custom wrapper for your environment by inheriting from the MPEnvWrapper.
     # You can also add other gym.Wrappers in case they are needed.
-    wrappers = [HoleReacherMPWrapper]
+    wrappers = [MPWrapper]
     mp_kwargs = {
         "num_dof": 5,
         "num_basis": 5,

alr_envs/examples/pd_control_gain_tuning.py (new file, 74 lines)
@@ -0,0 +1,74 @@
+import numpy as np
+from matplotlib import pyplot as plt
+
+from alr_envs import dmc
+from alr_envs.utils.make_env_helpers import make_detpmp_env
+
+# This might work for some environments, however, please verify either way the correct trajectory information
+# for your environment are extracted below
+SEED = 10
+env_id = "cartpole-swingup"
+wrappers = [dmc.suite.cartpole.MPWrapper]
+
+mp_kwargs = {
+    "num_dof": 1,
+    "num_basis": 5,
+    "duration": 2,
+    "width": 0.025,
+    "policy_type": "motor",
+    "weights_scale": 0.2,
+    "zero_start": True,
+    "policy_kwargs": {
+        "p_gains": 10,
+        "d_gains": 10  # a good starting point is the sqrt of p_gains
+    }
+}
+
+kwargs = dict(time_limit=2, episode_length=200)
+
+env = make_detpmp_env(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs,
+                      **kwargs)
+
+# Plot difference between real trajectory and target MP trajectory
+env.reset()
+pos, vel = env.mp_rollout(env.action_space.sample())
+
+base_shape = env.full_action_space.shape
+actual_pos = np.zeros((len(pos), *base_shape))
+actual_pos_ball = np.zeros((len(pos), *base_shape))
+actual_vel = np.zeros((len(pos), *base_shape))
+act = np.zeros((len(pos), *base_shape))
+
+for t, pos_vel in enumerate(zip(pos, vel)):
+    actions = env.policy.get_action(pos_vel[0], pos_vel[1])
+    actions = np.clip(actions, env.full_action_space.low, env.full_action_space.high)
+    _, _, _, _ = env.env.step(actions)
+    act[t, :] = actions
+    # TODO verify for your environment
+    actual_pos[t, :] = env.current_pos
+    # actual_pos_ball[t, :] = env.physics.data.qpos[2:]
+    actual_vel[t, :] = env.current_vel
+
+plt.figure(figsize=(15, 5))
+
+plt.subplot(131)
+plt.title("Position")
+plt.plot(actual_pos, c='C0', label=["true" if i == 0 else "" for i in range(np.prod(base_shape))])
+# plt.plot(actual_pos_ball, label="true pos ball")
+plt.plot(pos, c='C1', label=["MP" if i == 0 else "" for i in range(np.prod(base_shape))])
+plt.xlabel("Episode steps")
+plt.legend()
+
+plt.subplot(132)
+plt.title("Velocity")
+plt.plot(actual_vel, c='C0', label=[f"true" if i == 0 else "" for i in range(np.prod(base_shape))])
+plt.plot(vel, c='C1', label=[f"MP" if i == 0 else "" for i in range(np.prod(base_shape))])
+plt.xlabel("Episode steps")
+plt.legend()
+
+plt.subplot(133)
+plt.title("Actions")
+plt.plot(act, c="C0"),  # label=[f"actions" if i == 0 else "" for i in range(np.prod(base_action_shape))])
+plt.xlabel("Episode steps")
+# plt.legend()
+plt.show()

@@ -2,7 +2,7 @@ from typing import Tuple, Union
 
 import numpy as np
 
-from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper
+from mp_env_api import MPEnvWrapper
 
 
 class BallInACupMPWrapper(MPEnvWrapper):

@@ -0,0 +1,3 @@
+from alr_envs.open_ai.mujoco import reacher_v2
+from alr_envs.open_ai.robotics import fetch
+from alr_envs.open_ai.classic_control import continuous_mountain_car

alr_envs/open_ai/classic_control/__init__.py (new, empty file)

@@ -0,0 +1 @@
+from .mp_wrapper import MPWrapper

@@ -1,7 +1,7 @@
 from typing import Union
 
 import numpy as np
-from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper
+from mp_env_api import MPEnvWrapper
 
 
 class MPWrapper(MPEnvWrapper):

@@ -1 +0,0 @@
-from alr_envs.open_ai.continuous_mountain_car.mp_wrapper import MPWrapper

@@ -1 +0,0 @@
-from alr_envs.open_ai.fetch.mp_wrapper import MPWrapper

@@ -1,22 +0,0 @@
-from typing import Union
-
-import numpy as np
-from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper
-
-
-class MPWrapper(MPEnvWrapper):
-    @property
-    def current_vel(self) -> Union[float, int, np.ndarray]:
-        return self.unwrapped._get_obs()["observation"][-5:-1]
-
-    @property
-    def current_pos(self) -> Union[float, int, np.ndarray]:
-        return self.unwrapped._get_obs()["observation"][:4]
-
-    @property
-    def goal_pos(self):
-        raise ValueError("Goal position is not available and has to be learnt based on the environment.")
-
-    @property
-    def dt(self) -> Union[float, int]:
-        return self.env.dt

alr_envs/open_ai/mujoco/__init__.py (new, empty file)

alr_envs/open_ai/mujoco/reacher_v2/__init__.py (new file)
@@ -0,0 +1 @@
+from .mp_wrapper import MPWrapper

@@ -1,7 +1,7 @@
 from typing import Union
 
 import numpy as np
-from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper
+from mp_env_api import MPEnvWrapper
 
 
 class MPWrapper(MPEnvWrapper):

@@ -1 +0,0 @@
-from alr_envs.open_ai.reacher_v2.mp_wrapper import MPWrapper

alr_envs/open_ai/robotics/__init__.py (new, empty file)

alr_envs/open_ai/robotics/fetch/__init__.py (new file)
@@ -0,0 +1 @@
+from .mp_wrapper import MPWrapper

alr_envs/open_ai/robotics/fetch/mp_wrapper.py (new file, 49 lines)
@@ -0,0 +1,49 @@
+from typing import Union
+
+import numpy as np
+
+from mp_env_api import MPEnvWrapper
+
+
+class MPWrapper(MPEnvWrapper):
+
+    @property
+    def active_obs(self):
+        return np.hstack([
+            [False] * 3,  # achieved goal
+            [True] * 3,  # desired/true goal
+            [False] * 3,  # grip pos
+            [True, True, False] * int(self.has_object),  # object position
+            [True, True, False] * int(self.has_object),  # object relative position
+            [False] * 2,  # gripper state
+            [False] * 3 * int(self.has_object),  # object rotation
+            [False] * 3 * int(self.has_object),  # object velocity position
+            [False] * 3 * int(self.has_object),  # object velocity rotation
+            [False] * 3,  # grip velocity position
+            [False] * 2,  # gripper velocity
+        ]).astype(bool)
+
+    @property
+    def current_vel(self) -> Union[float, int, np.ndarray]:
+        dt = self.sim.nsubsteps * self.sim.model.opt.timestep
+        grip_velp = self.sim.data.get_site_xvelp("robot0:grip") * dt
+        # gripper state should be symmetric for left and right.
+        # They are controlled with only one action for both gripper joints
+        gripper_state = self.sim.data.get_joint_qvel('robot0:r_gripper_finger_joint') * dt
+        return np.hstack([grip_velp, gripper_state])
+
+    @property
+    def current_pos(self) -> Union[float, int, np.ndarray]:
+        grip_pos = self.sim.data.get_site_xpos("robot0:grip")
+        # gripper state should be symmetric for left and right.
+        # They are controlled with only one action for both gripper joints
+        gripper_state = self.sim.data.get_joint_qpos('robot0:r_gripper_finger_joint')
+        return np.hstack([grip_pos, gripper_state])
+
+    @property
+    def goal_pos(self):
+        raise ValueError("Goal position is not available and has to be learnt based on the environment.")
+
+    @property
+    def dt(self) -> Union[float, int]:
+        return self.env.dt
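The active_obs property yields a boolean mask over the flattened observation, which downstream MP code can use to pick the context the policy conditions on. A sketch for FetchReach (has_object == 0, so all object-related blocks are empty and 16 entries remain; the layout assumes the flattening order shown in the note above):

    import numpy as np

    mask = np.hstack([
        [False] * 3,  # achieved goal
        [True] * 3,   # desired goal -- the only active block
        [False] * 3,  # grip pos
        [False] * 2,  # gripper state
        [False] * 3,  # grip velocity
        [False] * 2,  # gripper velocity
    ]).astype(bool)

    obs = np.arange(16.0)  # stand-in for a flattened FetchReach observation
    context = obs[mask]    # -> the three desired-goal entries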
@@ -37,7 +37,6 @@ def make(
     episode_length = 250 if domain_name == "manipulation" else 1000
-
     max_episode_steps = (episode_length + frame_skip - 1) // frame_skip
 
     if env_id not in gym.envs.registry.env_specs:
         task_kwargs = {'random': seed}
         # if seed is not None:

@@ -46,7 +45,7 @@ def make(
         task_kwargs['time_limit'] = time_limit
         register(
             id=env_id,
-            entry_point='alr_envs.utils.dmc2gym_wrapper:DMCWrapper',
+            entry_point='alr_envs.utils.dmc_wrapper:DMCWrapper',
             kwargs=dict(
                 domain_name=domain_name,
                 task_name=task_name,

@@ -33,11 +33,14 @@ def _spec_to_box(spec):
 
 
 def _flatten_obs(obs: collections.MutableMapping):
-    # obs_pieces = []
-    # for v in obs.values():
-    #     flat = np.array([v]) if np.isscalar(v) else v.ravel()
-    #     obs_pieces.append(flat)
-    # return np.concatenate(obs_pieces, axis=0)
+    """
+    Flattens an observation of type MutableMapping, e.g. a dict to a 1D array.
+    Args:
+        obs: observation to flatten
+
+    Returns: 1D array of observation
+
+    """
 
     if not isinstance(obs, collections.MutableMapping):
         raise ValueError(f'Requires dict-like observations structure. {type(obs)} found.')

@@ -52,19 +55,19 @@ def _flatten_obs(obs: collections.MutableMapping):
 class DMCWrapper(core.Env):
     def __init__(
             self,
-            domain_name,
-            task_name,
-            task_kwargs={},
-            visualize_reward=True,
-            from_pixels=False,
-            height=84,
-            width=84,
-            camera_id=0,
-            frame_skip=1,
-            environment_kwargs=None,
-            channels_first=True
+            domain_name: str,
+            task_name: str,
+            task_kwargs: dict = {},
+            visualize_reward: bool = True,
+            from_pixels: bool = False,
+            height: int = 84,
+            width: int = 84,
+            camera_id: int = 0,
+            frame_skip: int = 1,
+            environment_kwargs: dict = None,
+            channels_first: bool = True
     ):
-        assert 'random' in task_kwargs, 'please specify a seed, for deterministic behaviour'
+        assert 'random' in task_kwargs, 'Please specify a seed for deterministic behavior.'
         self._from_pixels = from_pixels
         self._height = height
         self._width = width

@@ -74,7 +77,7 @@ class DMCWrapper(core.Env):
 
         # create task
         if domain_name == "manipulation":
-            assert not from_pixels, \
+            assert not from_pixels and not task_name.endswith("_vision"), \
                 "TODO: Vision interface for manipulation is different to suite and needs to be implemented"
             self._env = manipulation.load(environment_name=task_name, seed=task_kwargs['random'])
         else:

@@ -169,11 +172,12 @@ class DMCWrapper(core.Env):
         if self._last_state is None:
             raise ValueError('Environment not ready to render. Call reset() first.')
 
+        camera_id = camera_id or self._camera_id
+
         # assert mode == 'rgb_array', 'only support rgb_array mode, given %s' % mode
         if mode == "rgb_array":
             height = height or self._height
             width = width or self._width
-            camera_id = camera_id or self._camera_id
             return self._env.physics.render(height=height, width=width, camera_id=camera_id)
 
         elif mode == 'human':

@@ -184,7 +188,8 @@ class DMCWrapper(core.Env):
                 self.viewer = rendering.SimpleImageViewer()
             # Render max available buffer size. Larger is only possible by altering the XML.
             img = self._env.physics.render(height=self._env.physics.model.vis.global_.offheight,
-                                           width=self._env.physics.model.vis.global_.offwidth)
+                                           width=self._env.physics.model.vis.global_.offwidth,
+                                           camera_id=camera_id)
             self.viewer.imshow(img)
             return self.viewer.isopen

@@ -2,13 +2,14 @@ import logging
 from typing import Iterable, List, Type, Union
 
 import gym
+import numpy as np
 
-from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper
+from mp_env_api import MPEnvWrapper
 from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper
 from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper
 
 
-def make_env_rank(env_id: str, seed: int, rank: int = 0, **kwargs):
+def make_env_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwargs):
     """
     TODO: Do we need this?
     Generate a callable to create a new gym environment with a given seed.

@@ -22,11 +23,16 @@ def make_env_rank(env_id: str, seed: int, rank: int = 0, **kwargs):
         env_id: name of the environment
         seed: seed for deterministic behaviour
         rank: environment rank for deterministic over multiple seeds behaviour
+        return_callable: If True returns a callable to create the environment instead of the environment itself.
 
     Returns:
 
     """
-    return lambda: make_env(env_id, seed + rank, **kwargs)
+
+    def f():
+        return make_env(env_id, seed + rank, **kwargs)
+
+    return f if return_callable else f()
 
 
 def make_env(env_id: str, seed, **kwargs):
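make_env_rank keeps its old factory behavior by default (return_callable=True), which is exactly what gym.vector.AsyncVectorEnv consumes, while return_callable=False now returns the environment directly. Usage as in the updated example above:

    import gym
    import alr_envs

    # each worker environment is seeded with seed + rank
    envs = gym.vector.AsyncVectorEnv(
        [alr_envs.make_env_rank("alr_envs:HoleReacher-v0", seed=0, rank=i) for i in range(4)]
    )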
@@ -103,6 +109,9 @@ def make_dmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs
     verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
 
     _env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
+
+    verify_dof(_env, mp_kwargs.get("num_dof"))
+
     return DmpWrapper(_env, **mp_kwargs)
 
 

@@ -120,6 +129,9 @@ def make_detpmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwa
     verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
 
     _env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
+
+    verify_dof(_env, mp_kwargs.get("num_dof"))
+
     return DetPMPWrapper(_env, **mp_kwargs)
 
 

@@ -185,5 +197,12 @@ def verify_time_limit(mp_time_limit: Union[None, float], env_time_limit: Union[N
     """
     if mp_time_limit is not None and env_time_limit is not None:
         assert mp_time_limit == env_time_limit, \
-            f"The manually specified 'time_limit' of {env_time_limit}s does not match " \
+            f"The specified 'time_limit' of {env_time_limit}s does not match " \
             f"the duration of {mp_time_limit}s for the MP."
+
+
+def verify_dof(base_env: gym.Env, dof: int):
+    action_shape = np.prod(base_env.action_space.shape)
+    assert dof == action_shape, \
+        f"The specified degrees of freedom ('num_dof') {dof} do not match " \
+        f"the action space of {action_shape} the base environments"
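The new verify_dof check turns a mismatch between mp_kwargs["num_dof"] and the wrapped action space into an immediate failure at construction time instead of a shape error mid-rollout. A sketch of the condition it enforces (hypothetical values):

    import gym
    import numpy as np

    space = gym.spaces.Box(low=-1, high=1, shape=(1,))  # e.g. cartpole's 1-D action space
    assert 1 == np.prod(space.shape)  # num_dof=1 passes
    # num_dof=5 would raise the AssertionError introduced above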
@@ -19,4 +19,3 @@ def angle_normalize(x, type="deg"):
 
     two_pi = 2 * np.pi
     return x - two_pi * np.floor((x + np.pi) / two_pi)
-

setup.py
@@ -3,7 +3,7 @@ from setuptools import setup
 setup(
     name='alr_envs',
     version='0.0.1',
-    packages=['alr_envs', 'alr_envs.classic_control', 'alr_envs.open_ai', 'alr_envs.mujoco', 'alr_envs.stochastic_search',
+    packages=['alr_envs', 'alr_envs.classic_control', 'alr_envs.open_ai', 'alr_envs.mujoco', 'alr_envs.dmc',
               'alr_envs.utils'],
     install_requires=[
         'gym',

@@ -88,11 +88,8 @@ class TestEnvironments(unittest.TestCase):
     def test_environment_determinism(self):
         """Tests that identical seeds produce identical trajectories."""
         seed = 0
-        # Iterate over two trajectories generated using identical sequences of
-        # random actions, and with identical task random states. Check that the
-        # observations, rewards, discounts and step types are identical.
+        # Iterate over two trajectories, which should have the same state and action sequence
         for spec in ALL_SPECS:
-            # try:
             with self.subTest(msg=spec.id):
                 self._run_env(spec.id)
                 traj1 = self._run_env(spec.id, seed=seed)