diff --git a/alr_envs/__init__.py b/alr_envs/__init__.py index cfa6251..5a92082 100644 --- a/alr_envs/__init__.py +++ b/alr_envs/__init__.py @@ -1,15 +1,12 @@ -import numpy as np from gym.envs.registration import register +from gym.wrappers import FlattenObservation -from alr_envs.classic_control.hole_reacher.hole_reacher_mp_wrapper import HoleReacherMPWrapper -from alr_envs.classic_control.simple_reacher.simple_reacher_mp_wrapper import SimpleReacherMPWrapper -from alr_envs.classic_control.viapoint_reacher.viapoint_reacher_mp_wrapper import ViaPointReacherMPWrapper -from alr_envs.dmc.manipulation.reach.reach_mp_wrapper import DMCReachSiteMPWrapper -from alr_envs.dmc.suite.ball_in_cup.ball_in_cup_mp_wrapper import DMCBallInCupMPWrapper -from alr_envs.dmc.suite.cartpole.cartpole_mp_wrapper import DMCCartpoleMPWrapper, DMCCartpoleThreePolesMPWrapper, \ - DMCCartpoleTwoPolesMPWrapper -from alr_envs.open_ai import reacher_v2, continuous_mountain_car, fetch -from alr_envs.dmc.suite.reacher.reacher_mp_wrapper import DMCReacherMPWrapper +from alr_envs import classic_control, dmc, open_ai + +from alr_envs.utils.make_env_helpers import make_dmp_env +from alr_envs.utils.make_env_helpers import make_detpmp_env +from alr_envs.utils.make_env_helpers import make_env +from alr_envs.utils.make_env_helpers import make_env_rank # Mujoco @@ -206,7 +203,7 @@ for v in versions: # max_episode_steps=1, kwargs={ "name": f"alr_envs:{v}", - "wrappers": [SimpleReacherMPWrapper], + "wrappers": [classic_control.simple_reacher.MPWrapper], "mp_kwargs": { "num_dof": 2 if "long" not in v.lower() else 5, "num_basis": 5, @@ -225,7 +222,7 @@ register( # max_episode_steps=1, kwargs={ "name": "alr_envs:ViaPointReacher-v0", - "wrappers": [ViaPointReacherMPWrapper], + "wrappers": [classic_control.viapoint_reacher.MPWrapper], "mp_kwargs": { "num_dof": 5, "num_basis": 5, @@ -238,6 +235,25 @@ register( } ) +register( + id='ViaPointReacherDetPMP-v0', + entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', + # max_episode_steps=1, + kwargs={ + "name": "alr_envs:ViaPointReacher-v0", + "wrappers": [classic_control.viapoint_reacher.MPWrapper], + "mp_kwargs": { + "num_dof": 5, + "num_basis": 5, + "duration": 2, + "width": 0.025, + "policy_type": "velocity", + "weights_scale": 0.2, + "zero_start": True + } + } +) + ## Hole Reacher versions = ["v0", "v1", "v2"] for v in versions: @@ -247,7 +263,7 @@ for v in versions: # max_episode_steps=1, kwargs={ "name": f"alr_envs:HoleReacher-{v}", - "wrappers": [HoleReacherMPWrapper], + "wrappers": [classic_control.hole_reacher.MPWrapper], "mp_kwargs": { "num_dof": 5, "num_basis": 5, @@ -267,7 +283,7 @@ for v in versions: entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', kwargs={ "name": f"alr_envs:HoleReacher-{v}", - "wrappers": [HoleReacherMPWrapper], + "wrappers": [classic_control.hole_reacher.MPWrapper], "mp_kwargs": { "num_dof": 5, "num_basis": 5, @@ -283,11 +299,6 @@ for v in versions: ## Deep Mind Control Suite (DMC) ### Suite -# tasks = ["ball_in_cup-catch", "reacher-easy", "reacher-hard", "cartpole-balance", "cartpole-balance_sparse", -# "cartpole-swingup", "cartpole-swingup_sparse", "cartpole-two_poles", "cartpole-three_poles"] -# wrappers = [DMCBallInCupMPWrapper, DMCReacherMPWrapper, DMCReacherMPWrapper, DMCCartpoleMPWrapper, -# partial(DMCCartpoleMPWrapper)] -# for t, w in zip(tasks, wrappers): register( id=f'dmc_ball_in_cup-catch_dmp-v0', entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper', @@ -296,7 +307,7 @@ register( "name": f"ball_in_cup-catch", "time_limit": 1, "episode_length": 50, - "wrappers": [DMCBallInCupMPWrapper], + "wrappers": [dmc.suite.ball_in_cup.MPWrapper], "mp_kwargs": { "num_dof": 2, "num_basis": 5, @@ -322,7 +333,7 @@ register( "name": f"ball_in_cup-catch", "time_limit": 1, "episode_length": 50, - "wrappers": [DMCBallInCupMPWrapper], + "wrappers": [dmc.suite.ball_in_cup.MPWrapper], "mp_kwargs": { "num_dof": 2, "num_basis": 5, @@ -339,7 +350,7 @@ register( } ) -# TODO tune gains and episode length for all below +# TODO tune episode length for all below register( id=f'dmc_reacher-easy_dmp-v0', entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper', @@ -348,7 +359,7 @@ register( "name": f"reacher-easy", "time_limit": 1, "episode_length": 50, - "wrappers": [DMCReacherMPWrapper], + "wrappers": [dmc.suite.reacher.MPWrapper], "mp_kwargs": { "num_dof": 2, "num_basis": 5, @@ -374,7 +385,7 @@ register( "name": f"reacher-easy", "time_limit": 1, "episode_length": 50, - "wrappers": [DMCReacherMPWrapper], + "wrappers": [dmc.suite.reacher.MPWrapper], "mp_kwargs": { "num_dof": 2, "num_basis": 5, @@ -399,7 +410,7 @@ register( "name": f"reacher-hard", "time_limit": 1, "episode_length": 50, - "wrappers": [DMCReacherMPWrapper], + "wrappers": [dmc.suite.reacher.MPWrapper], "mp_kwargs": { "num_dof": 2, "num_basis": 5, @@ -425,7 +436,7 @@ register( "name": f"reacher-hard", "time_limit": 1, "episode_length": 50, - "wrappers": [DMCReacherMPWrapper], + "wrappers": [dmc.suite.reacher.MPWrapper], "mp_kwargs": { "num_dof": 2, "num_basis": 5, @@ -448,8 +459,9 @@ register( kwargs={ "name": f"cartpole-balance", # "time_limit": 1, + "camera_id": 0, "episode_length": 1000, - "wrappers": [DMCCartpoleMPWrapper], + "wrappers": [dmc.suite.cartpole.MPWrapper], "mp_kwargs": { "num_dof": 1, "num_basis": 5, @@ -461,8 +473,8 @@ register( "weights_scale": 50, "goal_scale": 0.1, "policy_kwargs": { - "p_gains": 50, - "d_gains": 1 + "p_gains": 10, + "d_gains": 10 } } } @@ -474,8 +486,9 @@ register( kwargs={ "name": f"cartpole-balance", # "time_limit": 1, + "camera_id": 0, "episode_length": 1000, - "wrappers": [DMCCartpoleMPWrapper], + "wrappers": [dmc.suite.cartpole.MPWrapper], "mp_kwargs": { "num_dof": 1, "num_basis": 5, @@ -485,8 +498,8 @@ register( "weights_scale": 0.2, "zero_start": True, "policy_kwargs": { - "p_gains": 50, - "d_gains": 1 + "p_gains": 10, + "d_gains": 10 } } } @@ -498,8 +511,9 @@ register( kwargs={ "name": f"cartpole-balance_sparse", # "time_limit": 1, + "camera_id": 0, "episode_length": 1000, - "wrappers": [DMCCartpoleMPWrapper], + "wrappers": [dmc.suite.cartpole.MPWrapper], "mp_kwargs": { "num_dof": 1, "num_basis": 5, @@ -511,8 +525,8 @@ register( "weights_scale": 50, "goal_scale": 0.1, "policy_kwargs": { - "p_gains": 50, - "d_gains": 1 + "p_gains": 10, + "d_gains": 10 } } } @@ -524,8 +538,9 @@ register( kwargs={ "name": f"cartpole-balance_sparse", # "time_limit": 1, + "camera_id": 0, "episode_length": 1000, - "wrappers": [DMCCartpoleMPWrapper], + "wrappers": [dmc.suite.cartpole.MPWrapper], "mp_kwargs": { "num_dof": 1, "num_basis": 5, @@ -535,8 +550,8 @@ register( "weights_scale": 0.2, "zero_start": True, "policy_kwargs": { - "p_gains": 50, - "d_gains": 1 + "p_gains": 10, + "d_gains": 10 } } } @@ -549,8 +564,9 @@ register( kwargs={ "name": f"cartpole-swingup", # "time_limit": 1, + "camera_id": 0, "episode_length": 1000, - "wrappers": [DMCCartpoleMPWrapper], + "wrappers": [dmc.suite.cartpole.MPWrapper], "mp_kwargs": { "num_dof": 1, "num_basis": 5, @@ -562,8 +578,8 @@ register( "weights_scale": 50, "goal_scale": 0.1, "policy_kwargs": { - "p_gains": 50, - "d_gains": 1 + "p_gains": 10, + "d_gains": 10 } } } @@ -575,8 +591,9 @@ register( kwargs={ "name": f"cartpole-swingup", # "time_limit": 1, + "camera_id": 0, "episode_length": 1000, - "wrappers": [DMCCartpoleMPWrapper], + "wrappers": [dmc.suite.cartpole.MPWrapper], "mp_kwargs": { "num_dof": 1, "num_basis": 5, @@ -586,8 +603,8 @@ register( "weights_scale": 0.2, "zero_start": True, "policy_kwargs": { - "p_gains": 50, - "d_gains": 1 + "p_gains": 10, + "d_gains": 10 } } } @@ -599,8 +616,9 @@ register( kwargs={ "name": f"cartpole-swingup_sparse", # "time_limit": 1, + "camera_id": 0, "episode_length": 1000, - "wrappers": [DMCCartpoleMPWrapper], + "wrappers": [dmc.suite.cartpole.MPWrapper], "mp_kwargs": { "num_dof": 1, "num_basis": 5, @@ -612,8 +630,8 @@ register( "weights_scale": 50, "goal_scale": 0.1, "policy_kwargs": { - "p_gains": 50, - "d_gains": 1 + "p_gains": 10, + "d_gains": 10 } } } @@ -625,8 +643,9 @@ register( kwargs={ "name": f"cartpole-swingup_sparse", # "time_limit": 1, + "camera_id": 0, "episode_length": 1000, - "wrappers": [DMCCartpoleMPWrapper], + "wrappers": [dmc.suite.cartpole.MPWrapper], "mp_kwargs": { "num_dof": 1, "num_basis": 5, @@ -636,8 +655,8 @@ register( "weights_scale": 0.2, "zero_start": True, "policy_kwargs": { - "p_gains": 50, - "d_gains": 1 + "p_gains": 10, + "d_gains": 10 } } } @@ -649,9 +668,10 @@ register( kwargs={ "name": f"cartpole-two_poles", # "time_limit": 1, + "camera_id": 0, "episode_length": 1000, # "wrappers": [partial(DMCCartpoleMPWrapper, n_poles=2)], - "wrappers": [DMCCartpoleTwoPolesMPWrapper], + "wrappers": [dmc.suite.cartpole.TwoPolesMPWrapper], "mp_kwargs": { "num_dof": 1, "num_basis": 5, @@ -663,8 +683,8 @@ register( "weights_scale": 50, "goal_scale": 0.1, "policy_kwargs": { - "p_gains": 50, - "d_gains": 1 + "p_gains": 10, + "d_gains": 10 } } } @@ -676,9 +696,10 @@ register( kwargs={ "name": f"cartpole-two_poles", # "time_limit": 1, + "camera_id": 0, "episode_length": 1000, # "wrappers": [partial(DMCCartpoleMPWrapper, n_poles=2)], - "wrappers": [DMCCartpoleTwoPolesMPWrapper], + "wrappers": [dmc.suite.cartpole.TwoPolesMPWrapper], "mp_kwargs": { "num_dof": 1, "num_basis": 5, @@ -688,8 +709,8 @@ register( "weights_scale": 0.2, "zero_start": True, "policy_kwargs": { - "p_gains": 50, - "d_gains": 1 + "p_gains": 10, + "d_gains": 10 } } } @@ -701,9 +722,10 @@ register( kwargs={ "name": f"cartpole-three_poles", # "time_limit": 1, + "camera_id": 0, "episode_length": 1000, # "wrappers": [partial(DMCCartpoleMPWrapper, n_poles=3)], - "wrappers": [DMCCartpoleThreePolesMPWrapper], + "wrappers": [dmc.suite.cartpole.ThreePolesMPWrapper], "mp_kwargs": { "num_dof": 1, "num_basis": 5, @@ -715,8 +737,8 @@ register( "weights_scale": 50, "goal_scale": 0.1, "policy_kwargs": { - "p_gains": 50, - "d_gains": 1 + "p_gains": 10, + "d_gains": 10 } } } @@ -728,9 +750,10 @@ register( kwargs={ "name": f"cartpole-three_poles", # "time_limit": 1, + "camera_id": 0, "episode_length": 1000, # "wrappers": [partial(DMCCartpoleMPWrapper, n_poles=3)], - "wrappers": [DMCCartpoleThreePolesMPWrapper], + "wrappers": [dmc.suite.cartpole.ThreePolesMPWrapper], "mp_kwargs": { "num_dof": 1, "num_basis": 5, @@ -740,8 +763,8 @@ register( "weights_scale": 0.2, "zero_start": True, "policy_kwargs": { - "p_gains": 50, - "d_gains": 1 + "p_gains": 10, + "d_gains": 10 } } } @@ -757,7 +780,7 @@ register( "name": f"manipulation-reach_site_features", # "time_limit": 1, "episode_length": 250, - "wrappers": [DMCReachSiteMPWrapper], + "wrappers": [dmc.manipulation.reach.MPWrapper], "mp_kwargs": { "num_dof": 9, "num_basis": 5, @@ -779,7 +802,7 @@ register( "name": f"manipulation-reach_site_features", # "time_limit": 1, "episode_length": 250, - "wrappers": [DMCReachSiteMPWrapper], + "wrappers": [dmc.manipulation.reach.MPWrapper], "mp_kwargs": { "num_dof": 9, "num_basis": 5, @@ -798,7 +821,7 @@ register( entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', kwargs={ "name": "gym.envs.classic_control:MountainCarContinuous-v0", - "wrappers": [continuous_mountain_car.MPWrapper], + "wrappers": [open_ai.classic_control.continuous_mountain_car.MPWrapper], "mp_kwargs": { "num_dof": 1, "num_basis": 4, @@ -819,7 +842,7 @@ register( entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', kwargs={ "name": "gym.envs.mujoco:Reacher-v2", - "wrappers": [reacher_v2.MPWrapper], + "wrappers": [open_ai.mujoco.reacher_v2.MPWrapper], "mp_kwargs": { "num_dof": 2, "num_basis": 6, @@ -840,7 +863,7 @@ register( entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', kwargs={ "name": "gym.envs.robotics:FetchSlideDense-v1", - "wrappers": [fetch.MPWrapper], + "wrappers": [FlattenObservation, open_ai.robotics.fetch.MPWrapper], "mp_kwargs": { "num_dof": 4, "num_basis": 5, @@ -857,7 +880,7 @@ register( entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', kwargs={ "name": "gym.envs.robotics:FetchReachDense-v1", - "wrappers": [fetch.MPWrapper], + "wrappers": [FlattenObservation, open_ai.robotics.fetch.MPWrapper], "mp_kwargs": { "num_dof": 4, "num_basis": 5, diff --git a/alr_envs/classic_control/hole_reacher/__init__.py b/alr_envs/classic_control/hole_reacher/__init__.py index e69de29..c5e6d2f 100644 --- a/alr_envs/classic_control/hole_reacher/__init__.py +++ b/alr_envs/classic_control/hole_reacher/__init__.py @@ -0,0 +1 @@ +from .mp_wrapper import MPWrapper diff --git a/alr_envs/classic_control/hole_reacher/hole_reacher_mp_wrapper.py b/alr_envs/classic_control/hole_reacher/mp_wrapper.py similarity index 90% rename from alr_envs/classic_control/hole_reacher/hole_reacher_mp_wrapper.py rename to alr_envs/classic_control/hole_reacher/mp_wrapper.py index 12b5d19..d951161 100644 --- a/alr_envs/classic_control/hole_reacher/hole_reacher_mp_wrapper.py +++ b/alr_envs/classic_control/hole_reacher/mp_wrapper.py @@ -2,10 +2,10 @@ from typing import Tuple, Union import numpy as np -from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper +from mp_env_api import MPEnvWrapper -class HoleReacherMPWrapper(MPEnvWrapper): +class MPWrapper(MPEnvWrapper): @property def active_obs(self): return np.hstack([ diff --git a/alr_envs/classic_control/simple_reacher/__init__.py b/alr_envs/classic_control/simple_reacher/__init__.py index e69de29..989b5a9 100644 --- a/alr_envs/classic_control/simple_reacher/__init__.py +++ b/alr_envs/classic_control/simple_reacher/__init__.py @@ -0,0 +1 @@ +from .mp_wrapper import MPWrapper \ No newline at end of file diff --git a/alr_envs/classic_control/simple_reacher/simple_reacher_mp_wrapper.py b/alr_envs/classic_control/simple_reacher/mp_wrapper.py similarity index 89% rename from alr_envs/classic_control/simple_reacher/simple_reacher_mp_wrapper.py rename to alr_envs/classic_control/simple_reacher/mp_wrapper.py index 40426cf..4b71e3a 100644 --- a/alr_envs/classic_control/simple_reacher/simple_reacher_mp_wrapper.py +++ b/alr_envs/classic_control/simple_reacher/mp_wrapper.py @@ -2,10 +2,10 @@ from typing import Tuple, Union import numpy as np -from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper +from mp_env_api import MPEnvWrapper -class SimpleReacherMPWrapper(MPEnvWrapper): +class MPWrapper(MPEnvWrapper): @property def active_obs(self): return np.hstack([ diff --git a/alr_envs/classic_control/viapoint_reacher/__init__.py b/alr_envs/classic_control/viapoint_reacher/__init__.py index e69de29..989b5a9 100644 --- a/alr_envs/classic_control/viapoint_reacher/__init__.py +++ b/alr_envs/classic_control/viapoint_reacher/__init__.py @@ -0,0 +1 @@ +from .mp_wrapper import MPWrapper \ No newline at end of file diff --git a/alr_envs/classic_control/viapoint_reacher/viapoint_reacher_mp_wrapper.py b/alr_envs/classic_control/viapoint_reacher/mp_wrapper.py similarity index 89% rename from alr_envs/classic_control/viapoint_reacher/viapoint_reacher_mp_wrapper.py rename to alr_envs/classic_control/viapoint_reacher/mp_wrapper.py index a4a6ba3..6b3e85d 100644 --- a/alr_envs/classic_control/viapoint_reacher/viapoint_reacher_mp_wrapper.py +++ b/alr_envs/classic_control/viapoint_reacher/mp_wrapper.py @@ -2,10 +2,10 @@ from typing import Tuple, Union import numpy as np -from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper +from mp_env_api import MPEnvWrapper -class ViaPointReacherMPWrapper(MPEnvWrapper): +class MPWrapper(MPEnvWrapper): @property def active_obs(self): return np.hstack([ diff --git a/alr_envs/dmc/__init__.py b/alr_envs/dmc/__init__.py index e69de29..c5a343d 100644 --- a/alr_envs/dmc/__init__.py +++ b/alr_envs/dmc/__init__.py @@ -0,0 +1,5 @@ +# from alr_envs.dmc import manipulation, suite +from alr_envs.dmc.suite import ball_in_cup +from alr_envs.dmc.suite import reacher +from alr_envs.dmc.suite import cartpole +from alr_envs.dmc.manipulation import reach \ No newline at end of file diff --git a/alr_envs/dmc/manipulation/reach/__init__.py b/alr_envs/dmc/manipulation/reach/__init__.py index e69de29..989b5a9 100644 --- a/alr_envs/dmc/manipulation/reach/__init__.py +++ b/alr_envs/dmc/manipulation/reach/__init__.py @@ -0,0 +1 @@ +from .mp_wrapper import MPWrapper \ No newline at end of file diff --git a/alr_envs/dmc/manipulation/reach/reach_mp_wrapper.py b/alr_envs/dmc/manipulation/reach/mp_wrapper.py similarity index 90% rename from alr_envs/dmc/manipulation/reach/reach_mp_wrapper.py rename to alr_envs/dmc/manipulation/reach/mp_wrapper.py index 612b44d..2d03f7b 100644 --- a/alr_envs/dmc/manipulation/reach/reach_mp_wrapper.py +++ b/alr_envs/dmc/manipulation/reach/mp_wrapper.py @@ -2,10 +2,10 @@ from typing import Tuple, Union import numpy as np -from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper +from mp_env_api import MPEnvWrapper -class DMCReachSiteMPWrapper(MPEnvWrapper): +class MPWrapper(MPEnvWrapper): @property def active_obs(self): diff --git a/alr_envs/dmc/suite/ball_in_cup/__init__.py b/alr_envs/dmc/suite/ball_in_cup/__init__.py index e69de29..989b5a9 100644 --- a/alr_envs/dmc/suite/ball_in_cup/__init__.py +++ b/alr_envs/dmc/suite/ball_in_cup/__init__.py @@ -0,0 +1 @@ +from .mp_wrapper import MPWrapper \ No newline at end of file diff --git a/alr_envs/dmc/suite/ball_in_cup/ball_in_cup_mp_wrapper.py b/alr_envs/dmc/suite/ball_in_cup/mp_wrapper.py similarity index 90% rename from alr_envs/dmc/suite/ball_in_cup/ball_in_cup_mp_wrapper.py rename to alr_envs/dmc/suite/ball_in_cup/mp_wrapper.py index 514f6f1..fb068b3 100644 --- a/alr_envs/dmc/suite/ball_in_cup/ball_in_cup_mp_wrapper.py +++ b/alr_envs/dmc/suite/ball_in_cup/mp_wrapper.py @@ -2,10 +2,10 @@ from typing import Tuple, Union import numpy as np -from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper +from mp_env_api import MPEnvWrapper -class DMCBallInCupMPWrapper(MPEnvWrapper): +class MPWrapper(MPEnvWrapper): @property def active_obs(self): diff --git a/alr_envs/dmc/suite/cartpole/__init__.py b/alr_envs/dmc/suite/cartpole/__init__.py index e69de29..823077a 100644 --- a/alr_envs/dmc/suite/cartpole/__init__.py +++ b/alr_envs/dmc/suite/cartpole/__init__.py @@ -0,0 +1,3 @@ +from .mp_wrapper import MPWrapper +from .mp_wrapper import TwoPolesMPWrapper +from .mp_wrapper import ThreePolesMPWrapper \ No newline at end of file diff --git a/alr_envs/dmc/suite/cartpole/cartpole_mp_wrapper.py b/alr_envs/dmc/suite/cartpole/mp_wrapper.py similarity index 83% rename from alr_envs/dmc/suite/cartpole/cartpole_mp_wrapper.py rename to alr_envs/dmc/suite/cartpole/mp_wrapper.py index d8f8493..1ca99f5 100644 --- a/alr_envs/dmc/suite/cartpole/cartpole_mp_wrapper.py +++ b/alr_envs/dmc/suite/cartpole/mp_wrapper.py @@ -2,10 +2,10 @@ from typing import Tuple, Union import numpy as np -from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper +from mp_env_api import MPEnvWrapper -class DMCCartpoleMPWrapper(MPEnvWrapper): +class MPWrapper(MPEnvWrapper): def __init__(self, env, n_poles: int = 1): self.n_poles = n_poles @@ -39,13 +39,13 @@ class DMCCartpoleMPWrapper(MPEnvWrapper): return self.env.dt -class DMCCartpoleTwoPolesMPWrapper(DMCCartpoleMPWrapper): +class TwoPolesMPWrapper(MPWrapper): def __init__(self, env): super().__init__(env, n_poles=2) -class DMCCartpoleThreePolesMPWrapper(DMCCartpoleMPWrapper): +class ThreePolesMPWrapper(MPWrapper): def __init__(self, env): super().__init__(env, n_poles=3) diff --git a/alr_envs/dmc/suite/reacher/__init__.py b/alr_envs/dmc/suite/reacher/__init__.py index e69de29..989b5a9 100644 --- a/alr_envs/dmc/suite/reacher/__init__.py +++ b/alr_envs/dmc/suite/reacher/__init__.py @@ -0,0 +1 @@ +from .mp_wrapper import MPWrapper \ No newline at end of file diff --git a/alr_envs/dmc/suite/reacher/reacher_mp_wrapper.py b/alr_envs/dmc/suite/reacher/mp_wrapper.py similarity index 88% rename from alr_envs/dmc/suite/reacher/reacher_mp_wrapper.py rename to alr_envs/dmc/suite/reacher/mp_wrapper.py index 17baf04..86bc992 100644 --- a/alr_envs/dmc/suite/reacher/reacher_mp_wrapper.py +++ b/alr_envs/dmc/suite/reacher/mp_wrapper.py @@ -2,10 +2,10 @@ from typing import Tuple, Union import numpy as np -from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper +from mp_env_api import MPEnvWrapper -class DMCReacherMPWrapper(MPEnvWrapper): +class MPWrapper(MPEnvWrapper): @property def active_obs(self): diff --git a/alr_envs/examples/examples_dmc.py b/alr_envs/examples/examples_dmc.py index 6e24d83..b29329d 100644 --- a/alr_envs/examples/examples_dmc.py +++ b/alr_envs/examples/examples_dmc.py @@ -1,5 +1,5 @@ -from alr_envs.dmc.suite.ball_in_cup.ball_in_cup_mp_wrapper import DMCBallInCupMPWrapper -from alr_envs.utils.make_env_helpers import make_dmp_env, make_env +import alr_envs +from alr_envs.dmc.suite.ball_in_cup.mp_wrapper import MPWrapper def example_dmc(env_id="fish-swim", seed=1, iterations=1000, render=True): @@ -17,13 +17,12 @@ def example_dmc(env_id="fish-swim", seed=1, iterations=1000, render=True): Returns: """ - env = make_env(env_id, seed) + env = alr_envs.make_env(env_id, seed) rewards = 0 obs = env.reset() print("observation shape:", env.observation_space.shape) print("action shape:", env.action_space.shape) - # number of samples(multiple environment steps) for i in range(iterations): ac = env.action_space.sample() obs, reward, done, info = env.step(ac) @@ -63,7 +62,7 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True): # Replace this wrapper with the custom wrapper for your environment by inheriting from the MPEnvWrapper. # You can also add other gym.Wrappers in case they are needed. - wrappers = [DMCBallInCupMPWrapper] + wrappers = [MPWrapper] mp_kwargs = { "num_dof": 2, "num_basis": 5, @@ -84,9 +83,9 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True): "episode_length": 1000, # "frame_skip": 1 } - env = make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs) + env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs) # OR for a deterministic ProMP: - # env = make_detpmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args) + # env = alr_envs.make_detpmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args) # This renders the full MP trajectory # It is only required to call render() once in the beginning, which renders every consecutive trajectory. diff --git a/alr_envs/examples/examples_general.py b/alr_envs/examples/examples_general.py index d237c03..88d79d5 100644 --- a/alr_envs/examples/examples_general.py +++ b/alr_envs/examples/examples_general.py @@ -1,11 +1,9 @@ -import warnings from collections import defaultdict import gym import numpy as np -from alr_envs.utils.make_env_helpers import make_env, make_env_rank -from alr_envs.utils.mp_env_async_sampler import AlrContextualMpEnvSampler, AlrMpEnvSampler, DummyDist +import alr_envs def example_general(env_id="Pendulum-v0", seed=1, iterations=1000, render=True): @@ -23,7 +21,7 @@ def example_general(env_id="Pendulum-v0", seed=1, iterations=1000, render=True): """ - env = make_env(env_id, seed) + env = alr_envs.make_env(env_id, seed) rewards = 0 obs = env.reset() print("Observation shape: ", env.observation_space.shape) @@ -58,7 +56,7 @@ def example_async(env_id="alr_envs:HoleReacher-v0", n_cpu=4, seed=int('533D', 16 Returns: Tuple of (obs, reward, done, info) with type np.ndarray """ - env = gym.vector.AsyncVectorEnv([make_env_rank(env_id, seed, i) for i in range(n_cpu)]) + env = gym.vector.AsyncVectorEnv([alr_envs.make_env_rank(env_id, seed, i) for i in range(n_cpu)]) # OR # envs = gym.vector.AsyncVectorEnv([make_env(env_id, seed + i) for i in range(n_cpu)]) diff --git a/alr_envs/examples/examples_motion_primitives.py b/alr_envs/examples/examples_motion_primitives.py index 2dca54e..480b58d 100644 --- a/alr_envs/examples/examples_motion_primitives.py +++ b/alr_envs/examples/examples_motion_primitives.py @@ -1,4 +1,4 @@ -from alr_envs import HoleReacherMPWrapper +from alr_envs import MPWrapper from alr_envs.utils.make_env_helpers import make_dmp_env, make_env @@ -113,7 +113,7 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True): # Replace this wrapper with the custom wrapper for your environment by inheriting from the MPEnvWrapper. # You can also add other gym.Wrappers in case they are needed. - wrappers = [HoleReacherMPWrapper] + wrappers = [MPWrapper] mp_kwargs = { "num_dof": 5, "num_basis": 5, diff --git a/alr_envs/examples/pd_control_gain_tuning.py b/alr_envs/examples/pd_control_gain_tuning.py new file mode 100644 index 0000000..55ab1c0 --- /dev/null +++ b/alr_envs/examples/pd_control_gain_tuning.py @@ -0,0 +1,74 @@ +import numpy as np +from matplotlib import pyplot as plt + +from alr_envs import dmc +from alr_envs.utils.make_env_helpers import make_detpmp_env + +# This might work for some environments, however, please verify either way the correct trajectory information +# for your environment are extracted below +SEED = 10 +env_id = "cartpole-swingup" +wrappers = [dmc.suite.cartpole.MPWrapper] + +mp_kwargs = { + "num_dof": 1, + "num_basis": 5, + "duration": 2, + "width": 0.025, + "policy_type": "motor", + "weights_scale": 0.2, + "zero_start": True, + "policy_kwargs": { + "p_gains": 10, + "d_gains": 10 # a good starting point is the sqrt of p_gains + } +} + +kwargs = dict(time_limit=2, episode_length=200) + +env = make_detpmp_env(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs, + **kwargs) + +# Plot difference between real trajectory and target MP trajectory +env.reset() +pos, vel = env.mp_rollout(env.action_space.sample()) + +base_shape = env.full_action_space.shape +actual_pos = np.zeros((len(pos), *base_shape)) +actual_pos_ball = np.zeros((len(pos), *base_shape)) +actual_vel = np.zeros((len(pos), *base_shape)) +act = np.zeros((len(pos), *base_shape)) + +for t, pos_vel in enumerate(zip(pos, vel)): + actions = env.policy.get_action(pos_vel[0], pos_vel[1]) + actions = np.clip(actions, env.full_action_space.low, env.full_action_space.high) + _, _, _, _ = env.env.step(actions) + act[t, :] = actions + # TODO verify for your environment + actual_pos[t, :] = env.current_pos + # actual_pos_ball[t, :] = env.physics.data.qpos[2:] + actual_vel[t, :] = env.current_vel + +plt.figure(figsize=(15, 5)) + +plt.subplot(131) +plt.title("Position") +plt.plot(actual_pos, c='C0', label=["true" if i == 0 else "" for i in range(np.prod(base_shape))]) +# plt.plot(actual_pos_ball, label="true pos ball") +plt.plot(pos, c='C1', label=["MP" if i == 0 else "" for i in range(np.prod(base_shape))]) +plt.xlabel("Episode steps") +plt.legend() + +plt.subplot(132) +plt.title("Velocity") +plt.plot(actual_vel, c='C0', label=[f"true" if i == 0 else "" for i in range(np.prod(base_shape))]) +plt.plot(vel, c='C1', label=[f"MP" if i == 0 else "" for i in range(np.prod(base_shape))]) +plt.xlabel("Episode steps") +plt.legend() + +plt.subplot(133) +plt.title("Actions") +plt.plot(act, c="C0"), # label=[f"actions" if i == 0 else "" for i in range(np.prod(base_action_shape))]) +plt.xlabel("Episode steps") +# plt.legend() +plt.show() diff --git a/alr_envs/mujoco/ball_in_a_cup/ball_in_a_cup_mp_wrapper.py b/alr_envs/mujoco/ball_in_a_cup/ball_in_a_cup_mp_wrapper.py index 321358a..945fa8d 100644 --- a/alr_envs/mujoco/ball_in_a_cup/ball_in_a_cup_mp_wrapper.py +++ b/alr_envs/mujoco/ball_in_a_cup/ball_in_a_cup_mp_wrapper.py @@ -2,7 +2,7 @@ from typing import Tuple, Union import numpy as np -from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper +from mp_env_api import MPEnvWrapper class BallInACupMPWrapper(MPEnvWrapper): diff --git a/alr_envs/open_ai/__init__.py b/alr_envs/open_ai/__init__.py index e69de29..1e531cf 100644 --- a/alr_envs/open_ai/__init__.py +++ b/alr_envs/open_ai/__init__.py @@ -0,0 +1,3 @@ +from alr_envs.open_ai.mujoco import reacher_v2 +from alr_envs.open_ai.robotics import fetch +from alr_envs.open_ai.classic_control import continuous_mountain_car \ No newline at end of file diff --git a/alr_envs/open_ai/classic_control/__init__.py b/alr_envs/open_ai/classic_control/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/alr_envs/open_ai/classic_control/continuous_mountain_car/__init__.py b/alr_envs/open_ai/classic_control/continuous_mountain_car/__init__.py new file mode 100644 index 0000000..989b5a9 --- /dev/null +++ b/alr_envs/open_ai/classic_control/continuous_mountain_car/__init__.py @@ -0,0 +1 @@ +from .mp_wrapper import MPWrapper \ No newline at end of file diff --git a/alr_envs/open_ai/continuous_mountain_car/mp_wrapper.py b/alr_envs/open_ai/classic_control/continuous_mountain_car/mp_wrapper.py similarity index 85% rename from alr_envs/open_ai/continuous_mountain_car/mp_wrapper.py rename to alr_envs/open_ai/classic_control/continuous_mountain_car/mp_wrapper.py index 29378ed..2a2357a 100644 --- a/alr_envs/open_ai/continuous_mountain_car/mp_wrapper.py +++ b/alr_envs/open_ai/classic_control/continuous_mountain_car/mp_wrapper.py @@ -1,7 +1,7 @@ from typing import Union import numpy as np -from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper +from mp_env_api import MPEnvWrapper class MPWrapper(MPEnvWrapper): diff --git a/alr_envs/open_ai/continuous_mountain_car/__init__.py b/alr_envs/open_ai/continuous_mountain_car/__init__.py deleted file mode 100644 index 36f731d..0000000 --- a/alr_envs/open_ai/continuous_mountain_car/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from alr_envs.open_ai.continuous_mountain_car.mp_wrapper import MPWrapper \ No newline at end of file diff --git a/alr_envs/open_ai/fetch/__init__.py b/alr_envs/open_ai/fetch/__init__.py deleted file mode 100644 index 2e68176..0000000 --- a/alr_envs/open_ai/fetch/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from alr_envs.open_ai.fetch.mp_wrapper import MPWrapper \ No newline at end of file diff --git a/alr_envs/open_ai/fetch/mp_wrapper.py b/alr_envs/open_ai/fetch/mp_wrapper.py deleted file mode 100644 index 6602a18..0000000 --- a/alr_envs/open_ai/fetch/mp_wrapper.py +++ /dev/null @@ -1,22 +0,0 @@ -from typing import Union - -import numpy as np -from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper - - -class MPWrapper(MPEnvWrapper): - @property - def current_vel(self) -> Union[float, int, np.ndarray]: - return self.unwrapped._get_obs()["observation"][-5:-1] - - @property - def current_pos(self) -> Union[float, int, np.ndarray]: - return self.unwrapped._get_obs()["observation"][:4] - - @property - def goal_pos(self): - raise ValueError("Goal position is not available and has to be learnt based on the environment.") - - @property - def dt(self) -> Union[float, int]: - return self.env.dt \ No newline at end of file diff --git a/alr_envs/open_ai/mujoco/__init__.py b/alr_envs/open_ai/mujoco/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/alr_envs/open_ai/mujoco/reacher_v2/__init__.py b/alr_envs/open_ai/mujoco/reacher_v2/__init__.py new file mode 100644 index 0000000..989b5a9 --- /dev/null +++ b/alr_envs/open_ai/mujoco/reacher_v2/__init__.py @@ -0,0 +1 @@ +from .mp_wrapper import MPWrapper \ No newline at end of file diff --git a/alr_envs/open_ai/reacher_v2/mp_wrapper.py b/alr_envs/open_ai/mujoco/reacher_v2/mp_wrapper.py similarity index 78% rename from alr_envs/open_ai/reacher_v2/mp_wrapper.py rename to alr_envs/open_ai/mujoco/reacher_v2/mp_wrapper.py index d3181b5..16202e5 100644 --- a/alr_envs/open_ai/reacher_v2/mp_wrapper.py +++ b/alr_envs/open_ai/mujoco/reacher_v2/mp_wrapper.py @@ -1,7 +1,7 @@ from typing import Union import numpy as np -from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper +from mp_env_api import MPEnvWrapper class MPWrapper(MPEnvWrapper): diff --git a/alr_envs/open_ai/reacher_v2/__init__.py b/alr_envs/open_ai/reacher_v2/__init__.py deleted file mode 100644 index 48a5615..0000000 --- a/alr_envs/open_ai/reacher_v2/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from alr_envs.open_ai.reacher_v2.mp_wrapper import MPWrapper \ No newline at end of file diff --git a/alr_envs/open_ai/robotics/__init__.py b/alr_envs/open_ai/robotics/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/alr_envs/open_ai/robotics/fetch/__init__.py b/alr_envs/open_ai/robotics/fetch/__init__.py new file mode 100644 index 0000000..989b5a9 --- /dev/null +++ b/alr_envs/open_ai/robotics/fetch/__init__.py @@ -0,0 +1 @@ +from .mp_wrapper import MPWrapper \ No newline at end of file diff --git a/alr_envs/open_ai/robotics/fetch/mp_wrapper.py b/alr_envs/open_ai/robotics/fetch/mp_wrapper.py new file mode 100644 index 0000000..218e175 --- /dev/null +++ b/alr_envs/open_ai/robotics/fetch/mp_wrapper.py @@ -0,0 +1,49 @@ +from typing import Union + +import numpy as np + +from mp_env_api import MPEnvWrapper + + +class MPWrapper(MPEnvWrapper): + + @property + def active_obs(self): + return np.hstack([ + [False] * 3, # achieved goal + [True] * 3, # desired/true goal + [False] * 3, # grip pos + [True, True, False] * int(self.has_object), # object position + [True, True, False] * int(self.has_object), # object relative position + [False] * 2, # gripper state + [False] * 3 * int(self.has_object), # object rotation + [False] * 3 * int(self.has_object), # object velocity position + [False] * 3 * int(self.has_object), # object velocity rotation + [False] * 3, # grip velocity position + [False] * 2, # gripper velocity + ]).astype(bool) + + @property + def current_vel(self) -> Union[float, int, np.ndarray]: + dt = self.sim.nsubsteps * self.sim.model.opt.timestep + grip_velp = self.sim.data.get_site_xvelp("robot0:grip") * dt + # gripper state should be symmetric for left and right. + # They are controlled with only one action for both gripper joints + gripper_state = self.sim.data.get_joint_qvel('robot0:r_gripper_finger_joint') * dt + return np.hstack([grip_velp, gripper_state]) + + @property + def current_pos(self) -> Union[float, int, np.ndarray]: + grip_pos = self.sim.data.get_site_xpos("robot0:grip") + # gripper state should be symmetric for left and right. + # They are controlled with only one action for both gripper joints + gripper_state = self.sim.data.get_joint_qpos('robot0:r_gripper_finger_joint') + return np.hstack([grip_pos, gripper_state]) + + @property + def goal_pos(self): + raise ValueError("Goal position is not available and has to be learnt based on the environment.") + + @property + def dt(self) -> Union[float, int]: + return self.env.dt diff --git a/alr_envs/utils/__init__.py b/alr_envs/utils/__init__.py index 77fdd9f..758d49f 100644 --- a/alr_envs/utils/__init__.py +++ b/alr_envs/utils/__init__.py @@ -37,7 +37,6 @@ def make( episode_length = 250 if domain_name == "manipulation" else 1000 max_episode_steps = (episode_length + frame_skip - 1) // frame_skip - if env_id not in gym.envs.registry.env_specs: task_kwargs = {'random': seed} # if seed is not None: @@ -46,7 +45,7 @@ def make( task_kwargs['time_limit'] = time_limit register( id=env_id, - entry_point='alr_envs.utils.dmc2gym_wrapper:DMCWrapper', + entry_point='alr_envs.utils.dmc_wrapper:DMCWrapper', kwargs=dict( domain_name=domain_name, task_name=task_name, diff --git a/alr_envs/utils/dmc2gym_wrapper.py b/alr_envs/utils/dmc_wrapper.py similarity index 88% rename from alr_envs/utils/dmc2gym_wrapper.py rename to alr_envs/utils/dmc_wrapper.py index 5e6a53d..10f1af9 100644 --- a/alr_envs/utils/dmc2gym_wrapper.py +++ b/alr_envs/utils/dmc_wrapper.py @@ -33,11 +33,14 @@ def _spec_to_box(spec): def _flatten_obs(obs: collections.MutableMapping): - # obs_pieces = [] - # for v in obs.values(): - # flat = np.array([v]) if np.isscalar(v) else v.ravel() - # obs_pieces.append(flat) - # return np.concatenate(obs_pieces, axis=0) + """ + Flattens an observation of type MutableMapping, e.g. a dict to a 1D array. + Args: + obs: observation to flatten + + Returns: 1D array of observation + + """ if not isinstance(obs, collections.MutableMapping): raise ValueError(f'Requires dict-like observations structure. {type(obs)} found.') @@ -52,19 +55,19 @@ def _flatten_obs(obs: collections.MutableMapping): class DMCWrapper(core.Env): def __init__( self, - domain_name, - task_name, - task_kwargs={}, - visualize_reward=True, - from_pixels=False, - height=84, - width=84, - camera_id=0, - frame_skip=1, - environment_kwargs=None, - channels_first=True + domain_name: str, + task_name: str, + task_kwargs: dict = {}, + visualize_reward: bool = True, + from_pixels: bool = False, + height: int = 84, + width: int = 84, + camera_id: int = 0, + frame_skip: int = 1, + environment_kwargs: dict = None, + channels_first: bool = True ): - assert 'random' in task_kwargs, 'please specify a seed, for deterministic behaviour' + assert 'random' in task_kwargs, 'Please specify a seed for deterministic behavior.' self._from_pixels = from_pixels self._height = height self._width = width @@ -74,7 +77,7 @@ class DMCWrapper(core.Env): # create task if domain_name == "manipulation": - assert not from_pixels, \ + assert not from_pixels and not task_name.endswith("_vision"), \ "TODO: Vision interface for manipulation is different to suite and needs to be implemented" self._env = manipulation.load(environment_name=task_name, seed=task_kwargs['random']) else: @@ -169,11 +172,12 @@ class DMCWrapper(core.Env): if self._last_state is None: raise ValueError('Environment not ready to render. Call reset() first.') + camera_id = camera_id or self._camera_id + # assert mode == 'rgb_array', 'only support rgb_array mode, given %s' % mode if mode == "rgb_array": height = height or self._height width = width or self._width - camera_id = camera_id or self._camera_id return self._env.physics.render(height=height, width=width, camera_id=camera_id) elif mode == 'human': @@ -184,7 +188,8 @@ class DMCWrapper(core.Env): self.viewer = rendering.SimpleImageViewer() # Render max available buffer size. Larger is only possible by altering the XML. img = self._env.physics.render(height=self._env.physics.model.vis.global_.offheight, - width=self._env.physics.model.vis.global_.offwidth) + width=self._env.physics.model.vis.global_.offwidth, + camera_id=camera_id) self.viewer.imshow(img) return self.viewer.isopen diff --git a/alr_envs/utils/make_env_helpers.py b/alr_envs/utils/make_env_helpers.py index 0ba9dea..0348492 100644 --- a/alr_envs/utils/make_env_helpers.py +++ b/alr_envs/utils/make_env_helpers.py @@ -2,13 +2,14 @@ import logging from typing import Iterable, List, Type, Union import gym +import numpy as np -from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper +from mp_env_api import MPEnvWrapper from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper -def make_env_rank(env_id: str, seed: int, rank: int = 0, **kwargs): +def make_env_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwargs): """ TODO: Do we need this? Generate a callable to create a new gym environment with a given seed. @@ -22,11 +23,16 @@ def make_env_rank(env_id: str, seed: int, rank: int = 0, **kwargs): env_id: name of the environment seed: seed for deterministic behaviour rank: environment rank for deterministic over multiple seeds behaviour + return_callable: If True returns a callable to create the environment instead of the environment itself. Returns: """ - return lambda: make_env(env_id, seed + rank, **kwargs) + + def f(): + return make_env(env_id, seed + rank, **kwargs) + + return f if return_callable else f() def make_env(env_id: str, seed, **kwargs): @@ -103,6 +109,9 @@ def make_dmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None)) _env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs) + + verify_dof(_env, mp_kwargs.get("num_dof")) + return DmpWrapper(_env, **mp_kwargs) @@ -120,6 +129,9 @@ def make_detpmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwa verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None)) _env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs) + + verify_dof(_env, mp_kwargs.get("num_dof")) + return DetPMPWrapper(_env, **mp_kwargs) @@ -185,5 +197,12 @@ def verify_time_limit(mp_time_limit: Union[None, float], env_time_limit: Union[N """ if mp_time_limit is not None and env_time_limit is not None: assert mp_time_limit == env_time_limit, \ - f"The manually specified 'time_limit' of {env_time_limit}s does not match " \ + f"The specified 'time_limit' of {env_time_limit}s does not match " \ f"the duration of {mp_time_limit}s for the MP." + + +def verify_dof(base_env: gym.Env, dof: int): + action_shape = np.prod(base_env.action_space.shape) + assert dof == action_shape, \ + f"The specified degrees of freedom ('num_dof') {dof} do not match " \ + f"the action space of {action_shape} the base environments" diff --git a/alr_envs/utils/utils.py b/alr_envs/utils/utils.py index 89205bd..3354db3 100644 --- a/alr_envs/utils/utils.py +++ b/alr_envs/utils/utils.py @@ -15,8 +15,7 @@ def angle_normalize(x, type="deg"): if type not in ["deg", "rad"]: raise ValueError(f"Invalid type {type}. Choose one of 'deg' or 'rad'.") if type == "deg": - x = np.deg2rad(x) # x * pi / 180 + x = np.deg2rad(x) # x * pi / 180 two_pi = 2 * np.pi return x - two_pi * np.floor((x + np.pi) / two_pi) - diff --git a/setup.py b/setup.py index 189be19..16374e4 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup setup( name='alr_envs', version='0.0.1', - packages=['alr_envs', 'alr_envs.classic_control', 'alr_envs.open_ai', 'alr_envs.mujoco', 'alr_envs.stochastic_search', + packages=['alr_envs', 'alr_envs.classic_control', 'alr_envs.open_ai', 'alr_envs.mujoco', 'alr_envs.dmc', 'alr_envs.utils'], install_requires=[ 'gym', diff --git a/test/test_envs.py b/test/test_envs.py index addd1a3..2e5bf7e 100644 --- a/test/test_envs.py +++ b/test/test_envs.py @@ -88,11 +88,8 @@ class TestEnvironments(unittest.TestCase): def test_environment_determinism(self): """Tests that identical seeds produce identical trajectories.""" seed = 0 - # Iterate over two trajectories generated using identical sequences of - # random actions, and with identical task random states. Check that the - # observations, rewards, discounts and step types are identical. + # Iterate over two trajectories, which should have the same state and action sequence for spec in ALL_SPECS: - # try: with self.subTest(msg=spec.id): self._run_env(spec.id) traj1 = self._run_env(spec.id, seed=seed)