fixed OpenAI fetch tasks; added nicer imports
This commit is contained in:
parent
f5fcbf7f54
commit
a11965827d
@ -1,15 +1,12 @@
|
||||
import numpy as np
|
||||
from gym.envs.registration import register
|
||||
from gym.wrappers import FlattenObservation
|
||||
|
||||
from alr_envs.classic_control.hole_reacher.hole_reacher_mp_wrapper import HoleReacherMPWrapper
|
||||
from alr_envs.classic_control.simple_reacher.simple_reacher_mp_wrapper import SimpleReacherMPWrapper
|
||||
from alr_envs.classic_control.viapoint_reacher.viapoint_reacher_mp_wrapper import ViaPointReacherMPWrapper
|
||||
from alr_envs.dmc.manipulation.reach.reach_mp_wrapper import DMCReachSiteMPWrapper
|
||||
from alr_envs.dmc.suite.ball_in_cup.ball_in_cup_mp_wrapper import DMCBallInCupMPWrapper
|
||||
from alr_envs.dmc.suite.cartpole.cartpole_mp_wrapper import DMCCartpoleMPWrapper, DMCCartpoleThreePolesMPWrapper, \
|
||||
DMCCartpoleTwoPolesMPWrapper
|
||||
from alr_envs.open_ai import reacher_v2, continuous_mountain_car, fetch
|
||||
from alr_envs.dmc.suite.reacher.reacher_mp_wrapper import DMCReacherMPWrapper
|
||||
from alr_envs import classic_control, dmc, open_ai
|
||||
|
||||
from alr_envs.utils.make_env_helpers import make_dmp_env
|
||||
from alr_envs.utils.make_env_helpers import make_detpmp_env
|
||||
from alr_envs.utils.make_env_helpers import make_env
|
||||
from alr_envs.utils.make_env_helpers import make_env_rank
|
||||
|
||||
# Mujoco
|
||||
|
||||
@ -206,7 +203,7 @@ for v in versions:
|
||||
# max_episode_steps=1,
|
||||
kwargs={
|
||||
"name": f"alr_envs:{v}",
|
||||
"wrappers": [SimpleReacherMPWrapper],
|
||||
"wrappers": [classic_control.simple_reacher.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 2 if "long" not in v.lower() else 5,
|
||||
"num_basis": 5,
|
||||
@ -225,7 +222,7 @@ register(
|
||||
# max_episode_steps=1,
|
||||
kwargs={
|
||||
"name": "alr_envs:ViaPointReacher-v0",
|
||||
"wrappers": [ViaPointReacherMPWrapper],
|
||||
"wrappers": [classic_control.viapoint_reacher.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 5,
|
||||
"num_basis": 5,
|
||||
@ -238,6 +235,25 @@ register(
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id='ViaPointReacherDetPMP-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||
# max_episode_steps=1,
|
||||
kwargs={
|
||||
"name": "alr_envs:ViaPointReacher-v0",
|
||||
"wrappers": [classic_control.viapoint_reacher.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 5,
|
||||
"num_basis": 5,
|
||||
"duration": 2,
|
||||
"width": 0.025,
|
||||
"policy_type": "velocity",
|
||||
"weights_scale": 0.2,
|
||||
"zero_start": True
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
## Hole Reacher
|
||||
versions = ["v0", "v1", "v2"]
|
||||
for v in versions:
|
||||
@ -247,7 +263,7 @@ for v in versions:
|
||||
# max_episode_steps=1,
|
||||
kwargs={
|
||||
"name": f"alr_envs:HoleReacher-{v}",
|
||||
"wrappers": [HoleReacherMPWrapper],
|
||||
"wrappers": [classic_control.hole_reacher.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 5,
|
||||
"num_basis": 5,
|
||||
@ -267,7 +283,7 @@ for v in versions:
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||
kwargs={
|
||||
"name": f"alr_envs:HoleReacher-{v}",
|
||||
"wrappers": [HoleReacherMPWrapper],
|
||||
"wrappers": [classic_control.hole_reacher.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 5,
|
||||
"num_basis": 5,
|
||||
@ -283,11 +299,6 @@ for v in versions:
|
||||
## Deep Mind Control Suite (DMC)
|
||||
### Suite
|
||||
|
||||
# tasks = ["ball_in_cup-catch", "reacher-easy", "reacher-hard", "cartpole-balance", "cartpole-balance_sparse",
|
||||
# "cartpole-swingup", "cartpole-swingup_sparse", "cartpole-two_poles", "cartpole-three_poles"]
|
||||
# wrappers = [DMCBallInCupMPWrapper, DMCReacherMPWrapper, DMCReacherMPWrapper, DMCCartpoleMPWrapper,
|
||||
# partial(DMCCartpoleMPWrapper)]
|
||||
# for t, w in zip(tasks, wrappers):
|
||||
register(
|
||||
id=f'dmc_ball_in_cup-catch_dmp-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
|
||||
@ -296,7 +307,7 @@ register(
|
||||
"name": f"ball_in_cup-catch",
|
||||
"time_limit": 1,
|
||||
"episode_length": 50,
|
||||
"wrappers": [DMCBallInCupMPWrapper],
|
||||
"wrappers": [dmc.suite.ball_in_cup.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 2,
|
||||
"num_basis": 5,
|
||||
@ -322,7 +333,7 @@ register(
|
||||
"name": f"ball_in_cup-catch",
|
||||
"time_limit": 1,
|
||||
"episode_length": 50,
|
||||
"wrappers": [DMCBallInCupMPWrapper],
|
||||
"wrappers": [dmc.suite.ball_in_cup.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 2,
|
||||
"num_basis": 5,
|
||||
@ -339,7 +350,7 @@ register(
|
||||
}
|
||||
)
|
||||
|
||||
# TODO tune gains and episode length for all below
|
||||
# TODO tune episode length for all below
|
||||
register(
|
||||
id=f'dmc_reacher-easy_dmp-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
|
||||
@ -348,7 +359,7 @@ register(
|
||||
"name": f"reacher-easy",
|
||||
"time_limit": 1,
|
||||
"episode_length": 50,
|
||||
"wrappers": [DMCReacherMPWrapper],
|
||||
"wrappers": [dmc.suite.reacher.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 2,
|
||||
"num_basis": 5,
|
||||
@ -374,7 +385,7 @@ register(
|
||||
"name": f"reacher-easy",
|
||||
"time_limit": 1,
|
||||
"episode_length": 50,
|
||||
"wrappers": [DMCReacherMPWrapper],
|
||||
"wrappers": [dmc.suite.reacher.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 2,
|
||||
"num_basis": 5,
|
||||
@ -399,7 +410,7 @@ register(
|
||||
"name": f"reacher-hard",
|
||||
"time_limit": 1,
|
||||
"episode_length": 50,
|
||||
"wrappers": [DMCReacherMPWrapper],
|
||||
"wrappers": [dmc.suite.reacher.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 2,
|
||||
"num_basis": 5,
|
||||
@ -425,7 +436,7 @@ register(
|
||||
"name": f"reacher-hard",
|
||||
"time_limit": 1,
|
||||
"episode_length": 50,
|
||||
"wrappers": [DMCReacherMPWrapper],
|
||||
"wrappers": [dmc.suite.reacher.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 2,
|
||||
"num_basis": 5,
|
||||
@ -448,8 +459,9 @@ register(
|
||||
kwargs={
|
||||
"name": f"cartpole-balance",
|
||||
# "time_limit": 1,
|
||||
"camera_id": 0,
|
||||
"episode_length": 1000,
|
||||
"wrappers": [DMCCartpoleMPWrapper],
|
||||
"wrappers": [dmc.suite.cartpole.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 1,
|
||||
"num_basis": 5,
|
||||
@ -461,8 +473,8 @@ register(
|
||||
"weights_scale": 50,
|
||||
"goal_scale": 0.1,
|
||||
"policy_kwargs": {
|
||||
"p_gains": 50,
|
||||
"d_gains": 1
|
||||
"p_gains": 10,
|
||||
"d_gains": 10
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -474,8 +486,9 @@ register(
|
||||
kwargs={
|
||||
"name": f"cartpole-balance",
|
||||
# "time_limit": 1,
|
||||
"camera_id": 0,
|
||||
"episode_length": 1000,
|
||||
"wrappers": [DMCCartpoleMPWrapper],
|
||||
"wrappers": [dmc.suite.cartpole.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 1,
|
||||
"num_basis": 5,
|
||||
@ -485,8 +498,8 @@ register(
|
||||
"weights_scale": 0.2,
|
||||
"zero_start": True,
|
||||
"policy_kwargs": {
|
||||
"p_gains": 50,
|
||||
"d_gains": 1
|
||||
"p_gains": 10,
|
||||
"d_gains": 10
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -498,8 +511,9 @@ register(
|
||||
kwargs={
|
||||
"name": f"cartpole-balance_sparse",
|
||||
# "time_limit": 1,
|
||||
"camera_id": 0,
|
||||
"episode_length": 1000,
|
||||
"wrappers": [DMCCartpoleMPWrapper],
|
||||
"wrappers": [dmc.suite.cartpole.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 1,
|
||||
"num_basis": 5,
|
||||
@ -511,8 +525,8 @@ register(
|
||||
"weights_scale": 50,
|
||||
"goal_scale": 0.1,
|
||||
"policy_kwargs": {
|
||||
"p_gains": 50,
|
||||
"d_gains": 1
|
||||
"p_gains": 10,
|
||||
"d_gains": 10
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -524,8 +538,9 @@ register(
|
||||
kwargs={
|
||||
"name": f"cartpole-balance_sparse",
|
||||
# "time_limit": 1,
|
||||
"camera_id": 0,
|
||||
"episode_length": 1000,
|
||||
"wrappers": [DMCCartpoleMPWrapper],
|
||||
"wrappers": [dmc.suite.cartpole.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 1,
|
||||
"num_basis": 5,
|
||||
@ -535,8 +550,8 @@ register(
|
||||
"weights_scale": 0.2,
|
||||
"zero_start": True,
|
||||
"policy_kwargs": {
|
||||
"p_gains": 50,
|
||||
"d_gains": 1
|
||||
"p_gains": 10,
|
||||
"d_gains": 10
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -549,8 +564,9 @@ register(
|
||||
kwargs={
|
||||
"name": f"cartpole-swingup",
|
||||
# "time_limit": 1,
|
||||
"camera_id": 0,
|
||||
"episode_length": 1000,
|
||||
"wrappers": [DMCCartpoleMPWrapper],
|
||||
"wrappers": [dmc.suite.cartpole.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 1,
|
||||
"num_basis": 5,
|
||||
@ -562,8 +578,8 @@ register(
|
||||
"weights_scale": 50,
|
||||
"goal_scale": 0.1,
|
||||
"policy_kwargs": {
|
||||
"p_gains": 50,
|
||||
"d_gains": 1
|
||||
"p_gains": 10,
|
||||
"d_gains": 10
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -575,8 +591,9 @@ register(
|
||||
kwargs={
|
||||
"name": f"cartpole-swingup",
|
||||
# "time_limit": 1,
|
||||
"camera_id": 0,
|
||||
"episode_length": 1000,
|
||||
"wrappers": [DMCCartpoleMPWrapper],
|
||||
"wrappers": [dmc.suite.cartpole.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 1,
|
||||
"num_basis": 5,
|
||||
@ -586,8 +603,8 @@ register(
|
||||
"weights_scale": 0.2,
|
||||
"zero_start": True,
|
||||
"policy_kwargs": {
|
||||
"p_gains": 50,
|
||||
"d_gains": 1
|
||||
"p_gains": 10,
|
||||
"d_gains": 10
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -599,8 +616,9 @@ register(
|
||||
kwargs={
|
||||
"name": f"cartpole-swingup_sparse",
|
||||
# "time_limit": 1,
|
||||
"camera_id": 0,
|
||||
"episode_length": 1000,
|
||||
"wrappers": [DMCCartpoleMPWrapper],
|
||||
"wrappers": [dmc.suite.cartpole.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 1,
|
||||
"num_basis": 5,
|
||||
@ -612,8 +630,8 @@ register(
|
||||
"weights_scale": 50,
|
||||
"goal_scale": 0.1,
|
||||
"policy_kwargs": {
|
||||
"p_gains": 50,
|
||||
"d_gains": 1
|
||||
"p_gains": 10,
|
||||
"d_gains": 10
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -625,8 +643,9 @@ register(
|
||||
kwargs={
|
||||
"name": f"cartpole-swingup_sparse",
|
||||
# "time_limit": 1,
|
||||
"camera_id": 0,
|
||||
"episode_length": 1000,
|
||||
"wrappers": [DMCCartpoleMPWrapper],
|
||||
"wrappers": [dmc.suite.cartpole.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 1,
|
||||
"num_basis": 5,
|
||||
@ -636,8 +655,8 @@ register(
|
||||
"weights_scale": 0.2,
|
||||
"zero_start": True,
|
||||
"policy_kwargs": {
|
||||
"p_gains": 50,
|
||||
"d_gains": 1
|
||||
"p_gains": 10,
|
||||
"d_gains": 10
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -649,9 +668,10 @@ register(
|
||||
kwargs={
|
||||
"name": f"cartpole-two_poles",
|
||||
# "time_limit": 1,
|
||||
"camera_id": 0,
|
||||
"episode_length": 1000,
|
||||
# "wrappers": [partial(DMCCartpoleMPWrapper, n_poles=2)],
|
||||
"wrappers": [DMCCartpoleTwoPolesMPWrapper],
|
||||
"wrappers": [dmc.suite.cartpole.TwoPolesMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 1,
|
||||
"num_basis": 5,
|
||||
@ -663,8 +683,8 @@ register(
|
||||
"weights_scale": 50,
|
||||
"goal_scale": 0.1,
|
||||
"policy_kwargs": {
|
||||
"p_gains": 50,
|
||||
"d_gains": 1
|
||||
"p_gains": 10,
|
||||
"d_gains": 10
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -676,9 +696,10 @@ register(
|
||||
kwargs={
|
||||
"name": f"cartpole-two_poles",
|
||||
# "time_limit": 1,
|
||||
"camera_id": 0,
|
||||
"episode_length": 1000,
|
||||
# "wrappers": [partial(DMCCartpoleMPWrapper, n_poles=2)],
|
||||
"wrappers": [DMCCartpoleTwoPolesMPWrapper],
|
||||
"wrappers": [dmc.suite.cartpole.TwoPolesMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 1,
|
||||
"num_basis": 5,
|
||||
@ -688,8 +709,8 @@ register(
|
||||
"weights_scale": 0.2,
|
||||
"zero_start": True,
|
||||
"policy_kwargs": {
|
||||
"p_gains": 50,
|
||||
"d_gains": 1
|
||||
"p_gains": 10,
|
||||
"d_gains": 10
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -701,9 +722,10 @@ register(
|
||||
kwargs={
|
||||
"name": f"cartpole-three_poles",
|
||||
# "time_limit": 1,
|
||||
"camera_id": 0,
|
||||
"episode_length": 1000,
|
||||
# "wrappers": [partial(DMCCartpoleMPWrapper, n_poles=3)],
|
||||
"wrappers": [DMCCartpoleThreePolesMPWrapper],
|
||||
"wrappers": [dmc.suite.cartpole.ThreePolesMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 1,
|
||||
"num_basis": 5,
|
||||
@ -715,8 +737,8 @@ register(
|
||||
"weights_scale": 50,
|
||||
"goal_scale": 0.1,
|
||||
"policy_kwargs": {
|
||||
"p_gains": 50,
|
||||
"d_gains": 1
|
||||
"p_gains": 10,
|
||||
"d_gains": 10
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -728,9 +750,10 @@ register(
|
||||
kwargs={
|
||||
"name": f"cartpole-three_poles",
|
||||
# "time_limit": 1,
|
||||
"camera_id": 0,
|
||||
"episode_length": 1000,
|
||||
# "wrappers": [partial(DMCCartpoleMPWrapper, n_poles=3)],
|
||||
"wrappers": [DMCCartpoleThreePolesMPWrapper],
|
||||
"wrappers": [dmc.suite.cartpole.ThreePolesMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 1,
|
||||
"num_basis": 5,
|
||||
@ -740,8 +763,8 @@ register(
|
||||
"weights_scale": 0.2,
|
||||
"zero_start": True,
|
||||
"policy_kwargs": {
|
||||
"p_gains": 50,
|
||||
"d_gains": 1
|
||||
"p_gains": 10,
|
||||
"d_gains": 10
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -757,7 +780,7 @@ register(
|
||||
"name": f"manipulation-reach_site_features",
|
||||
# "time_limit": 1,
|
||||
"episode_length": 250,
|
||||
"wrappers": [DMCReachSiteMPWrapper],
|
||||
"wrappers": [dmc.manipulation.reach.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 9,
|
||||
"num_basis": 5,
|
||||
@ -779,7 +802,7 @@ register(
|
||||
"name": f"manipulation-reach_site_features",
|
||||
# "time_limit": 1,
|
||||
"episode_length": 250,
|
||||
"wrappers": [DMCReachSiteMPWrapper],
|
||||
"wrappers": [dmc.manipulation.reach.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 9,
|
||||
"num_basis": 5,
|
||||
@ -798,7 +821,7 @@ register(
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||
kwargs={
|
||||
"name": "gym.envs.classic_control:MountainCarContinuous-v0",
|
||||
"wrappers": [continuous_mountain_car.MPWrapper],
|
||||
"wrappers": [open_ai.classic_control.continuous_mountain_car.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 1,
|
||||
"num_basis": 4,
|
||||
@ -819,7 +842,7 @@ register(
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||
kwargs={
|
||||
"name": "gym.envs.mujoco:Reacher-v2",
|
||||
"wrappers": [reacher_v2.MPWrapper],
|
||||
"wrappers": [open_ai.mujoco.reacher_v2.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 2,
|
||||
"num_basis": 6,
|
||||
@ -840,7 +863,7 @@ register(
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||
kwargs={
|
||||
"name": "gym.envs.robotics:FetchSlideDense-v1",
|
||||
"wrappers": [fetch.MPWrapper],
|
||||
"wrappers": [FlattenObservation, open_ai.robotics.fetch.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 4,
|
||||
"num_basis": 5,
|
||||
@ -857,7 +880,7 @@ register(
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||
kwargs={
|
||||
"name": "gym.envs.robotics:FetchReachDense-v1",
|
||||
"wrappers": [fetch.MPWrapper],
|
||||
"wrappers": [FlattenObservation, open_ai.robotics.fetch.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 4,
|
||||
"num_basis": 5,
|
||||
|
@ -0,0 +1 @@
|
||||
from .mp_wrapper import MPWrapper
|
@ -2,10 +2,10 @@ from typing import Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper
|
||||
from mp_env_api import MPEnvWrapper
|
||||
|
||||
|
||||
class HoleReacherMPWrapper(MPEnvWrapper):
|
||||
class MPWrapper(MPEnvWrapper):
|
||||
@property
|
||||
def active_obs(self):
|
||||
return np.hstack([
|
@ -0,0 +1 @@
|
||||
from .mp_wrapper import MPWrapper
|
@ -2,10 +2,10 @@ from typing import Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper
|
||||
from mp_env_api import MPEnvWrapper
|
||||
|
||||
|
||||
class SimpleReacherMPWrapper(MPEnvWrapper):
|
||||
class MPWrapper(MPEnvWrapper):
|
||||
@property
|
||||
def active_obs(self):
|
||||
return np.hstack([
|
@ -0,0 +1 @@
|
||||
from .mp_wrapper import MPWrapper
|
@ -2,10 +2,10 @@ from typing import Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper
|
||||
from mp_env_api import MPEnvWrapper
|
||||
|
||||
|
||||
class ViaPointReacherMPWrapper(MPEnvWrapper):
|
||||
class MPWrapper(MPEnvWrapper):
|
||||
@property
|
||||
def active_obs(self):
|
||||
return np.hstack([
|
@ -0,0 +1,5 @@
|
||||
# from alr_envs.dmc import manipulation, suite
|
||||
from alr_envs.dmc.suite import ball_in_cup
|
||||
from alr_envs.dmc.suite import reacher
|
||||
from alr_envs.dmc.suite import cartpole
|
||||
from alr_envs.dmc.manipulation import reach
|
@ -0,0 +1 @@
|
||||
from .mp_wrapper import MPWrapper
|
@ -2,10 +2,10 @@ from typing import Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper
|
||||
from mp_env_api import MPEnvWrapper
|
||||
|
||||
|
||||
class DMCReachSiteMPWrapper(MPEnvWrapper):
|
||||
class MPWrapper(MPEnvWrapper):
|
||||
|
||||
@property
|
||||
def active_obs(self):
|
@ -0,0 +1 @@
|
||||
from .mp_wrapper import MPWrapper
|
@ -2,10 +2,10 @@ from typing import Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper
|
||||
from mp_env_api import MPEnvWrapper
|
||||
|
||||
|
||||
class DMCBallInCupMPWrapper(MPEnvWrapper):
|
||||
class MPWrapper(MPEnvWrapper):
|
||||
|
||||
@property
|
||||
def active_obs(self):
|
@ -0,0 +1,3 @@
|
||||
from .mp_wrapper import MPWrapper
|
||||
from .mp_wrapper import TwoPolesMPWrapper
|
||||
from .mp_wrapper import ThreePolesMPWrapper
|
@ -2,10 +2,10 @@ from typing import Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper
|
||||
from mp_env_api import MPEnvWrapper
|
||||
|
||||
|
||||
class DMCCartpoleMPWrapper(MPEnvWrapper):
|
||||
class MPWrapper(MPEnvWrapper):
|
||||
|
||||
def __init__(self, env, n_poles: int = 1):
|
||||
self.n_poles = n_poles
|
||||
@ -39,13 +39,13 @@ class DMCCartpoleMPWrapper(MPEnvWrapper):
|
||||
return self.env.dt
|
||||
|
||||
|
||||
class DMCCartpoleTwoPolesMPWrapper(DMCCartpoleMPWrapper):
|
||||
class TwoPolesMPWrapper(MPWrapper):
|
||||
|
||||
def __init__(self, env):
|
||||
super().__init__(env, n_poles=2)
|
||||
|
||||
|
||||
class DMCCartpoleThreePolesMPWrapper(DMCCartpoleMPWrapper):
|
||||
class ThreePolesMPWrapper(MPWrapper):
|
||||
|
||||
def __init__(self, env):
|
||||
super().__init__(env, n_poles=3)
|
@ -0,0 +1 @@
|
||||
from .mp_wrapper import MPWrapper
|
@ -2,10 +2,10 @@ from typing import Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper
|
||||
from mp_env_api import MPEnvWrapper
|
||||
|
||||
|
||||
class DMCReacherMPWrapper(MPEnvWrapper):
|
||||
class MPWrapper(MPEnvWrapper):
|
||||
|
||||
@property
|
||||
def active_obs(self):
|
@ -1,5 +1,5 @@
|
||||
from alr_envs.dmc.suite.ball_in_cup.ball_in_cup_mp_wrapper import DMCBallInCupMPWrapper
|
||||
from alr_envs.utils.make_env_helpers import make_dmp_env, make_env
|
||||
import alr_envs
|
||||
from alr_envs.dmc.suite.ball_in_cup.mp_wrapper import MPWrapper
|
||||
|
||||
|
||||
def example_dmc(env_id="fish-swim", seed=1, iterations=1000, render=True):
|
||||
@ -17,13 +17,12 @@ def example_dmc(env_id="fish-swim", seed=1, iterations=1000, render=True):
|
||||
Returns:
|
||||
|
||||
"""
|
||||
env = make_env(env_id, seed)
|
||||
env = alr_envs.make_env(env_id, seed)
|
||||
rewards = 0
|
||||
obs = env.reset()
|
||||
print("observation shape:", env.observation_space.shape)
|
||||
print("action shape:", env.action_space.shape)
|
||||
|
||||
# number of samples(multiple environment steps)
|
||||
for i in range(iterations):
|
||||
ac = env.action_space.sample()
|
||||
obs, reward, done, info = env.step(ac)
|
||||
@ -63,7 +62,7 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
|
||||
|
||||
# Replace this wrapper with the custom wrapper for your environment by inheriting from the MPEnvWrapper.
|
||||
# You can also add other gym.Wrappers in case they are needed.
|
||||
wrappers = [DMCBallInCupMPWrapper]
|
||||
wrappers = [MPWrapper]
|
||||
mp_kwargs = {
|
||||
"num_dof": 2,
|
||||
"num_basis": 5,
|
||||
@ -84,9 +83,9 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
|
||||
"episode_length": 1000,
|
||||
# "frame_skip": 1
|
||||
}
|
||||
env = make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs)
|
||||
env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs)
|
||||
# OR for a deterministic ProMP:
|
||||
# env = make_detpmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args)
|
||||
# env = alr_envs.make_detpmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args)
|
||||
|
||||
# This renders the full MP trajectory
|
||||
# It is only required to call render() once in the beginning, which renders every consecutive trajectory.
|
||||
|
@ -1,11 +1,9 @@
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
||||
from alr_envs.utils.make_env_helpers import make_env, make_env_rank
|
||||
from alr_envs.utils.mp_env_async_sampler import AlrContextualMpEnvSampler, AlrMpEnvSampler, DummyDist
|
||||
import alr_envs
|
||||
|
||||
|
||||
def example_general(env_id="Pendulum-v0", seed=1, iterations=1000, render=True):
|
||||
@ -23,7 +21,7 @@ def example_general(env_id="Pendulum-v0", seed=1, iterations=1000, render=True):
|
||||
|
||||
"""
|
||||
|
||||
env = make_env(env_id, seed)
|
||||
env = alr_envs.make_env(env_id, seed)
|
||||
rewards = 0
|
||||
obs = env.reset()
|
||||
print("Observation shape: ", env.observation_space.shape)
|
||||
@ -58,7 +56,7 @@ def example_async(env_id="alr_envs:HoleReacher-v0", n_cpu=4, seed=int('533D', 16
|
||||
Returns: Tuple of (obs, reward, done, info) with type np.ndarray
|
||||
|
||||
"""
|
||||
env = gym.vector.AsyncVectorEnv([make_env_rank(env_id, seed, i) for i in range(n_cpu)])
|
||||
env = gym.vector.AsyncVectorEnv([alr_envs.make_env_rank(env_id, seed, i) for i in range(n_cpu)])
|
||||
# OR
|
||||
# envs = gym.vector.AsyncVectorEnv([make_env(env_id, seed + i) for i in range(n_cpu)])
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
from alr_envs import HoleReacherMPWrapper
|
||||
from alr_envs import MPWrapper
|
||||
from alr_envs.utils.make_env_helpers import make_dmp_env, make_env
|
||||
|
||||
|
||||
@ -113,7 +113,7 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
|
||||
|
||||
# Replace this wrapper with the custom wrapper for your environment by inheriting from the MPEnvWrapper.
|
||||
# You can also add other gym.Wrappers in case they are needed.
|
||||
wrappers = [HoleReacherMPWrapper]
|
||||
wrappers = [MPWrapper]
|
||||
mp_kwargs = {
|
||||
"num_dof": 5,
|
||||
"num_basis": 5,
|
||||
|
74
alr_envs/examples/pd_control_gain_tuning.py
Normal file
74
alr_envs/examples/pd_control_gain_tuning.py
Normal file
@ -0,0 +1,74 @@
|
||||
import numpy as np
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from alr_envs import dmc
|
||||
from alr_envs.utils.make_env_helpers import make_detpmp_env
|
||||
|
||||
# This might work for some environments, however, please verify either way the correct trajectory information
|
||||
# for your environment are extracted below
|
||||
SEED = 10
|
||||
env_id = "cartpole-swingup"
|
||||
wrappers = [dmc.suite.cartpole.MPWrapper]
|
||||
|
||||
mp_kwargs = {
|
||||
"num_dof": 1,
|
||||
"num_basis": 5,
|
||||
"duration": 2,
|
||||
"width": 0.025,
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 0.2,
|
||||
"zero_start": True,
|
||||
"policy_kwargs": {
|
||||
"p_gains": 10,
|
||||
"d_gains": 10 # a good starting point is the sqrt of p_gains
|
||||
}
|
||||
}
|
||||
|
||||
kwargs = dict(time_limit=2, episode_length=200)
|
||||
|
||||
env = make_detpmp_env(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs,
|
||||
**kwargs)
|
||||
|
||||
# Plot difference between real trajectory and target MP trajectory
|
||||
env.reset()
|
||||
pos, vel = env.mp_rollout(env.action_space.sample())
|
||||
|
||||
base_shape = env.full_action_space.shape
|
||||
actual_pos = np.zeros((len(pos), *base_shape))
|
||||
actual_pos_ball = np.zeros((len(pos), *base_shape))
|
||||
actual_vel = np.zeros((len(pos), *base_shape))
|
||||
act = np.zeros((len(pos), *base_shape))
|
||||
|
||||
for t, pos_vel in enumerate(zip(pos, vel)):
|
||||
actions = env.policy.get_action(pos_vel[0], pos_vel[1])
|
||||
actions = np.clip(actions, env.full_action_space.low, env.full_action_space.high)
|
||||
_, _, _, _ = env.env.step(actions)
|
||||
act[t, :] = actions
|
||||
# TODO verify for your environment
|
||||
actual_pos[t, :] = env.current_pos
|
||||
# actual_pos_ball[t, :] = env.physics.data.qpos[2:]
|
||||
actual_vel[t, :] = env.current_vel
|
||||
|
||||
plt.figure(figsize=(15, 5))
|
||||
|
||||
plt.subplot(131)
|
||||
plt.title("Position")
|
||||
plt.plot(actual_pos, c='C0', label=["true" if i == 0 else "" for i in range(np.prod(base_shape))])
|
||||
# plt.plot(actual_pos_ball, label="true pos ball")
|
||||
plt.plot(pos, c='C1', label=["MP" if i == 0 else "" for i in range(np.prod(base_shape))])
|
||||
plt.xlabel("Episode steps")
|
||||
plt.legend()
|
||||
|
||||
plt.subplot(132)
|
||||
plt.title("Velocity")
|
||||
plt.plot(actual_vel, c='C0', label=[f"true" if i == 0 else "" for i in range(np.prod(base_shape))])
|
||||
plt.plot(vel, c='C1', label=[f"MP" if i == 0 else "" for i in range(np.prod(base_shape))])
|
||||
plt.xlabel("Episode steps")
|
||||
plt.legend()
|
||||
|
||||
plt.subplot(133)
|
||||
plt.title("Actions")
|
||||
plt.plot(act, c="C0"), # label=[f"actions" if i == 0 else "" for i in range(np.prod(base_action_shape))])
|
||||
plt.xlabel("Episode steps")
|
||||
# plt.legend()
|
||||
plt.show()
|
@ -2,7 +2,7 @@ from typing import Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper
|
||||
from mp_env_api import MPEnvWrapper
|
||||
|
||||
|
||||
class BallInACupMPWrapper(MPEnvWrapper):
|
||||
|
@ -0,0 +1,3 @@
|
||||
from alr_envs.open_ai.mujoco import reacher_v2
|
||||
from alr_envs.open_ai.robotics import fetch
|
||||
from alr_envs.open_ai.classic_control import continuous_mountain_car
|
0
alr_envs/open_ai/classic_control/__init__.py
Normal file
0
alr_envs/open_ai/classic_control/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .mp_wrapper import MPWrapper
|
@ -1,7 +1,7 @@
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper
|
||||
from mp_env_api import MPEnvWrapper
|
||||
|
||||
|
||||
class MPWrapper(MPEnvWrapper):
|
@ -1 +0,0 @@
|
||||
from alr_envs.open_ai.continuous_mountain_car.mp_wrapper import MPWrapper
|
@ -1 +0,0 @@
|
||||
from alr_envs.open_ai.fetch.mp_wrapper import MPWrapper
|
@ -1,22 +0,0 @@
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper
|
||||
|
||||
|
||||
class MPWrapper(MPEnvWrapper):
|
||||
@property
|
||||
def current_vel(self) -> Union[float, int, np.ndarray]:
|
||||
return self.unwrapped._get_obs()["observation"][-5:-1]
|
||||
|
||||
@property
|
||||
def current_pos(self) -> Union[float, int, np.ndarray]:
|
||||
return self.unwrapped._get_obs()["observation"][:4]
|
||||
|
||||
@property
|
||||
def goal_pos(self):
|
||||
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
|
||||
|
||||
@property
|
||||
def dt(self) -> Union[float, int]:
|
||||
return self.env.dt
|
0
alr_envs/open_ai/mujoco/__init__.py
Normal file
0
alr_envs/open_ai/mujoco/__init__.py
Normal file
1
alr_envs/open_ai/mujoco/reacher_v2/__init__.py
Normal file
1
alr_envs/open_ai/mujoco/reacher_v2/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .mp_wrapper import MPWrapper
|
@ -1,7 +1,7 @@
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper
|
||||
from mp_env_api import MPEnvWrapper
|
||||
|
||||
|
||||
class MPWrapper(MPEnvWrapper):
|
@ -1 +0,0 @@
|
||||
from alr_envs.open_ai.reacher_v2.mp_wrapper import MPWrapper
|
0
alr_envs/open_ai/robotics/__init__.py
Normal file
0
alr_envs/open_ai/robotics/__init__.py
Normal file
1
alr_envs/open_ai/robotics/fetch/__init__.py
Normal file
1
alr_envs/open_ai/robotics/fetch/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .mp_wrapper import MPWrapper
|
49
alr_envs/open_ai/robotics/fetch/mp_wrapper.py
Normal file
49
alr_envs/open_ai/robotics/fetch/mp_wrapper.py
Normal file
@ -0,0 +1,49 @@
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mp_env_api import MPEnvWrapper
|
||||
|
||||
|
||||
class MPWrapper(MPEnvWrapper):
|
||||
|
||||
@property
|
||||
def active_obs(self):
|
||||
return np.hstack([
|
||||
[False] * 3, # achieved goal
|
||||
[True] * 3, # desired/true goal
|
||||
[False] * 3, # grip pos
|
||||
[True, True, False] * int(self.has_object), # object position
|
||||
[True, True, False] * int(self.has_object), # object relative position
|
||||
[False] * 2, # gripper state
|
||||
[False] * 3 * int(self.has_object), # object rotation
|
||||
[False] * 3 * int(self.has_object), # object velocity position
|
||||
[False] * 3 * int(self.has_object), # object velocity rotation
|
||||
[False] * 3, # grip velocity position
|
||||
[False] * 2, # gripper velocity
|
||||
]).astype(bool)
|
||||
|
||||
@property
|
||||
def current_vel(self) -> Union[float, int, np.ndarray]:
|
||||
dt = self.sim.nsubsteps * self.sim.model.opt.timestep
|
||||
grip_velp = self.sim.data.get_site_xvelp("robot0:grip") * dt
|
||||
# gripper state should be symmetric for left and right.
|
||||
# They are controlled with only one action for both gripper joints
|
||||
gripper_state = self.sim.data.get_joint_qvel('robot0:r_gripper_finger_joint') * dt
|
||||
return np.hstack([grip_velp, gripper_state])
|
||||
|
||||
@property
|
||||
def current_pos(self) -> Union[float, int, np.ndarray]:
|
||||
grip_pos = self.sim.data.get_site_xpos("robot0:grip")
|
||||
# gripper state should be symmetric for left and right.
|
||||
# They are controlled with only one action for both gripper joints
|
||||
gripper_state = self.sim.data.get_joint_qpos('robot0:r_gripper_finger_joint')
|
||||
return np.hstack([grip_pos, gripper_state])
|
||||
|
||||
@property
|
||||
def goal_pos(self):
|
||||
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
|
||||
|
||||
@property
|
||||
def dt(self) -> Union[float, int]:
|
||||
return self.env.dt
|
@ -37,7 +37,6 @@ def make(
|
||||
episode_length = 250 if domain_name == "manipulation" else 1000
|
||||
|
||||
max_episode_steps = (episode_length + frame_skip - 1) // frame_skip
|
||||
|
||||
if env_id not in gym.envs.registry.env_specs:
|
||||
task_kwargs = {'random': seed}
|
||||
# if seed is not None:
|
||||
@ -46,7 +45,7 @@ def make(
|
||||
task_kwargs['time_limit'] = time_limit
|
||||
register(
|
||||
id=env_id,
|
||||
entry_point='alr_envs.utils.dmc2gym_wrapper:DMCWrapper',
|
||||
entry_point='alr_envs.utils.dmc_wrapper:DMCWrapper',
|
||||
kwargs=dict(
|
||||
domain_name=domain_name,
|
||||
task_name=task_name,
|
||||
|
@ -33,11 +33,14 @@ def _spec_to_box(spec):
|
||||
|
||||
|
||||
def _flatten_obs(obs: collections.MutableMapping):
|
||||
# obs_pieces = []
|
||||
# for v in obs.values():
|
||||
# flat = np.array([v]) if np.isscalar(v) else v.ravel()
|
||||
# obs_pieces.append(flat)
|
||||
# return np.concatenate(obs_pieces, axis=0)
|
||||
"""
|
||||
Flattens an observation of type MutableMapping, e.g. a dict to a 1D array.
|
||||
Args:
|
||||
obs: observation to flatten
|
||||
|
||||
Returns: 1D array of observation
|
||||
|
||||
"""
|
||||
|
||||
if not isinstance(obs, collections.MutableMapping):
|
||||
raise ValueError(f'Requires dict-like observations structure. {type(obs)} found.')
|
||||
@ -52,19 +55,19 @@ def _flatten_obs(obs: collections.MutableMapping):
|
||||
class DMCWrapper(core.Env):
|
||||
def __init__(
|
||||
self,
|
||||
domain_name,
|
||||
task_name,
|
||||
task_kwargs={},
|
||||
visualize_reward=True,
|
||||
from_pixels=False,
|
||||
height=84,
|
||||
width=84,
|
||||
camera_id=0,
|
||||
frame_skip=1,
|
||||
environment_kwargs=None,
|
||||
channels_first=True
|
||||
domain_name: str,
|
||||
task_name: str,
|
||||
task_kwargs: dict = {},
|
||||
visualize_reward: bool = True,
|
||||
from_pixels: bool = False,
|
||||
height: int = 84,
|
||||
width: int = 84,
|
||||
camera_id: int = 0,
|
||||
frame_skip: int = 1,
|
||||
environment_kwargs: dict = None,
|
||||
channels_first: bool = True
|
||||
):
|
||||
assert 'random' in task_kwargs, 'please specify a seed, for deterministic behaviour'
|
||||
assert 'random' in task_kwargs, 'Please specify a seed for deterministic behavior.'
|
||||
self._from_pixels = from_pixels
|
||||
self._height = height
|
||||
self._width = width
|
||||
@ -74,7 +77,7 @@ class DMCWrapper(core.Env):
|
||||
|
||||
# create task
|
||||
if domain_name == "manipulation":
|
||||
assert not from_pixels, \
|
||||
assert not from_pixels and not task_name.endswith("_vision"), \
|
||||
"TODO: Vision interface for manipulation is different to suite and needs to be implemented"
|
||||
self._env = manipulation.load(environment_name=task_name, seed=task_kwargs['random'])
|
||||
else:
|
||||
@ -169,11 +172,12 @@ class DMCWrapper(core.Env):
|
||||
if self._last_state is None:
|
||||
raise ValueError('Environment not ready to render. Call reset() first.')
|
||||
|
||||
camera_id = camera_id or self._camera_id
|
||||
|
||||
# assert mode == 'rgb_array', 'only support rgb_array mode, given %s' % mode
|
||||
if mode == "rgb_array":
|
||||
height = height or self._height
|
||||
width = width or self._width
|
||||
camera_id = camera_id or self._camera_id
|
||||
return self._env.physics.render(height=height, width=width, camera_id=camera_id)
|
||||
|
||||
elif mode == 'human':
|
||||
@ -184,7 +188,8 @@ class DMCWrapper(core.Env):
|
||||
self.viewer = rendering.SimpleImageViewer()
|
||||
# Render max available buffer size. Larger is only possible by altering the XML.
|
||||
img = self._env.physics.render(height=self._env.physics.model.vis.global_.offheight,
|
||||
width=self._env.physics.model.vis.global_.offwidth)
|
||||
width=self._env.physics.model.vis.global_.offwidth,
|
||||
camera_id=camera_id)
|
||||
self.viewer.imshow(img)
|
||||
return self.viewer.isopen
|
||||
|
@ -2,13 +2,14 @@ import logging
|
||||
from typing import Iterable, List, Type, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
||||
from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper
|
||||
from mp_env_api import MPEnvWrapper
|
||||
from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper
|
||||
from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper
|
||||
|
||||
|
||||
def make_env_rank(env_id: str, seed: int, rank: int = 0, **kwargs):
|
||||
def make_env_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwargs):
|
||||
"""
|
||||
TODO: Do we need this?
|
||||
Generate a callable to create a new gym environment with a given seed.
|
||||
@ -22,11 +23,16 @@ def make_env_rank(env_id: str, seed: int, rank: int = 0, **kwargs):
|
||||
env_id: name of the environment
|
||||
seed: seed for deterministic behaviour
|
||||
rank: environment rank for deterministic over multiple seeds behaviour
|
||||
return_callable: If True returns a callable to create the environment instead of the environment itself.
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
return lambda: make_env(env_id, seed + rank, **kwargs)
|
||||
|
||||
def f():
|
||||
return make_env(env_id, seed + rank, **kwargs)
|
||||
|
||||
return f if return_callable else f()
|
||||
|
||||
|
||||
def make_env(env_id: str, seed, **kwargs):
|
||||
@ -103,6 +109,9 @@ def make_dmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs
|
||||
verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
|
||||
|
||||
_env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
|
||||
|
||||
verify_dof(_env, mp_kwargs.get("num_dof"))
|
||||
|
||||
return DmpWrapper(_env, **mp_kwargs)
|
||||
|
||||
|
||||
@ -120,6 +129,9 @@ def make_detpmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwa
|
||||
verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
|
||||
|
||||
_env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
|
||||
|
||||
verify_dof(_env, mp_kwargs.get("num_dof"))
|
||||
|
||||
return DetPMPWrapper(_env, **mp_kwargs)
|
||||
|
||||
|
||||
@ -185,5 +197,12 @@ def verify_time_limit(mp_time_limit: Union[None, float], env_time_limit: Union[N
|
||||
"""
|
||||
if mp_time_limit is not None and env_time_limit is not None:
|
||||
assert mp_time_limit == env_time_limit, \
|
||||
f"The manually specified 'time_limit' of {env_time_limit}s does not match " \
|
||||
f"The specified 'time_limit' of {env_time_limit}s does not match " \
|
||||
f"the duration of {mp_time_limit}s for the MP."
|
||||
|
||||
|
||||
def verify_dof(base_env: gym.Env, dof: int):
|
||||
action_shape = np.prod(base_env.action_space.shape)
|
||||
assert dof == action_shape, \
|
||||
f"The specified degrees of freedom ('num_dof') {dof} do not match " \
|
||||
f"the action space of {action_shape} the base environments"
|
||||
|
@ -19,4 +19,3 @@ def angle_normalize(x, type="deg"):
|
||||
|
||||
two_pi = 2 * np.pi
|
||||
return x - two_pi * np.floor((x + np.pi) / two_pi)
|
||||
|
||||
|
2
setup.py
2
setup.py
@ -3,7 +3,7 @@ from setuptools import setup
|
||||
setup(
|
||||
name='alr_envs',
|
||||
version='0.0.1',
|
||||
packages=['alr_envs', 'alr_envs.classic_control', 'alr_envs.open_ai', 'alr_envs.mujoco', 'alr_envs.stochastic_search',
|
||||
packages=['alr_envs', 'alr_envs.classic_control', 'alr_envs.open_ai', 'alr_envs.mujoco', 'alr_envs.dmc',
|
||||
'alr_envs.utils'],
|
||||
install_requires=[
|
||||
'gym',
|
||||
|
@ -88,11 +88,8 @@ class TestEnvironments(unittest.TestCase):
|
||||
def test_environment_determinism(self):
|
||||
"""Tests that identical seeds produce identical trajectories."""
|
||||
seed = 0
|
||||
# Iterate over two trajectories generated using identical sequences of
|
||||
# random actions, and with identical task random states. Check that the
|
||||
# observations, rewards, discounts and step types are identical.
|
||||
# Iterate over two trajectories, which should have the same state and action sequence
|
||||
for spec in ALL_SPECS:
|
||||
# try:
|
||||
with self.subTest(msg=spec.id):
|
||||
self._run_env(spec.id)
|
||||
traj1 = self._run_env(spec.id, seed=seed)
|
||||
|
Loading…
Reference in New Issue
Block a user