diff --git a/alr_envs/__init__.py b/alr_envs/__init__.py index 9fb3ae2..5664d7c 100644 --- a/alr_envs/__init__.py +++ b/alr_envs/__init__.py @@ -236,7 +236,7 @@ for v in versions: "mp_kwargs": { "num_dof": 2 if "long" not in v.lower() else 5, "num_basis": 5, - "duration": 2, + "duration": 20, "width": 0.025, "policy_type": "velocity", "weights_scale": 0.2, @@ -492,7 +492,7 @@ register( ) ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("dmc_reacher-hard_detpmp-v0") -dmc_cartpole_tasks = ["balance", "balance_sparse", "swingup", "swingup_sparse", "two_poles", "three_poles"] +dmc_cartpole_tasks = ["balance", "balance_sparse", "swingup", "swingup_sparse"] for task in dmc_cartpole_tasks: env_id = f'dmc_cartpole-{task}_dmp-v0' @@ -552,6 +552,120 @@ for task in dmc_cartpole_tasks: ) ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(env_id) +env_id = f'dmc_cartpole-two_poles_dmp-v0' +register( + id=env_id, + entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper', + # max_episode_steps=1, + kwargs={ + "name": f"cartpole-two_poles", + # "time_limit": 1, + "camera_id": 0, + "episode_length": 1000, + "wrappers": [dmc.suite.cartpole.TwoPolesMPWrapper], + "mp_kwargs": { + "num_dof": 1, + "num_basis": 5, + "duration": 10, + "learn_goal": True, + "alpha_phase": 2, + "bandwidth_factor": 2, + "policy_type": "motor", + "weights_scale": 50, + "goal_scale": 0.1, + "policy_kwargs": { + "p_gains": 10, + "d_gains": 10 + } + } + } +) +ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(env_id) + +env_id = f'dmc_cartpole-two_poles_detpmp-v0' +register( + id=env_id, + entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', + kwargs={ + "name": f"cartpole-two_poles", + # "time_limit": 1, + "camera_id": 0, + "episode_length": 1000, + "wrappers": [dmc.suite.cartpole.TwoPolesMPWrapper], + "mp_kwargs": { + "num_dof": 1, + "num_basis": 5, + "duration": 10, + "width": 0.025, + "policy_type": "motor", + "weights_scale": 0.2, + "zero_start": True, + 
"policy_kwargs": { + "p_gains": 10, + "d_gains": 10 + } + } + } +) +ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(env_id) + +env_id = f'dmc_cartpole-three_poles_dmp-v0' +register( + id=env_id, + entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper', + # max_episode_steps=1, + kwargs={ + "name": f"cartpole-three_poles", + # "time_limit": 1, + "camera_id": 0, + "episode_length": 1000, + "wrappers": [dmc.suite.cartpole.ThreePolesMPWrapper], + "mp_kwargs": { + "num_dof": 1, + "num_basis": 5, + "duration": 10, + "learn_goal": True, + "alpha_phase": 2, + "bandwidth_factor": 2, + "policy_type": "motor", + "weights_scale": 50, + "goal_scale": 0.1, + "policy_kwargs": { + "p_gains": 10, + "d_gains": 10 + } + } + } +) +ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(env_id) + +env_id = f'dmc_cartpole-three_poles_detpmp-v0' +register( + id=env_id, + entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', + kwargs={ + "name": f"cartpole-three_poles", + # "time_limit": 1, + "camera_id": 0, + "episode_length": 1000, + "wrappers": [dmc.suite.cartpole.ThreePolesMPWrapper], + "mp_kwargs": { + "num_dof": 1, + "num_basis": 5, + "duration": 10, + "width": 0.025, + "policy_type": "motor", + "weights_scale": 0.2, + "zero_start": True, + "policy_kwargs": { + "p_gains": 10, + "d_gains": 10 + } + } + } +) +ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(env_id) + ### Manipulation register( @@ -758,7 +872,7 @@ for task in object_change_envs: id=env_id, entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', kwargs={ - "name": env_id, + "name": task, "wrappers": [meta.object_change.MPWrapper], "mp_kwargs": { "num_dof": 4, @@ -792,7 +906,7 @@ for task in goal_and_object_change_envs: id=env_id, entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', kwargs={ - "name": env_id, + "name": task, "wrappers": [meta.goal_and_object_change.MPWrapper], "mp_kwargs": { "num_dof": 4, @@ -816,7 +930,7 @@ for task in 
goal_and_endeffector_change_envs: id=env_id, entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', kwargs={ - "name": env_id, + "name": task, "wrappers": [meta.goal_and_endeffector_change.MPWrapper], "mp_kwargs": { "num_dof": 4, diff --git a/alr_envs/dmc/__init__.py b/alr_envs/dmc/__init__.py index c5a343d..b1cf281 100644 --- a/alr_envs/dmc/__init__.py +++ b/alr_envs/dmc/__init__.py @@ -1,4 +1,3 @@ -# from alr_envs.dmc import manipulation, suite from alr_envs.dmc.suite import ball_in_cup from alr_envs.dmc.suite import reacher from alr_envs.dmc.suite import cartpole diff --git a/alr_envs/dmc/manipulation/__init__.py b/alr_envs/dmc/manipulation/__init__.py index e69de29..5d17d4d 100644 --- a/alr_envs/dmc/manipulation/__init__.py +++ b/alr_envs/dmc/manipulation/__init__.py @@ -0,0 +1 @@ +from alr_envs.dmc.manipulation import reach diff --git a/alr_envs/dmc/suite/__init__.py b/alr_envs/dmc/suite/__init__.py index e69de29..f889cad 100644 --- a/alr_envs/dmc/suite/__init__.py +++ b/alr_envs/dmc/suite/__init__.py @@ -0,0 +1 @@ +from alr_envs.dmc.suite import cartpole, ball_in_cup, reacher diff --git a/alr_envs/dmc/suite/cartpole/__init__.py b/alr_envs/dmc/suite/cartpole/__init__.py index 823077a..c5f9bee 100644 --- a/alr_envs/dmc/suite/cartpole/__init__.py +++ b/alr_envs/dmc/suite/cartpole/__init__.py @@ -1,3 +1,3 @@ from .mp_wrapper import MPWrapper from .mp_wrapper import TwoPolesMPWrapper -from .mp_wrapper import ThreePolesMPWrapper \ No newline at end of file +from .mp_wrapper import ThreePolesMPWrapper diff --git a/alr_envs/utils/make_env_helpers.py b/alr_envs/utils/make_env_helpers.py index 466f7cf..755b8ce 100644 --- a/alr_envs/utils/make_env_helpers.py +++ b/alr_envs/utils/make_env_helpers.py @@ -1,5 +1,4 @@ -import logging -from typing import Iterable, List, Type, Union +from typing import Iterable, Type, Union import gym import numpy as np diff --git a/test/test_envs.py b/test/test_envs.py index 8c21295..bf12693 100644 --- a/test/test_envs.py +++ 
b/test/test_envs.py @@ -99,6 +99,12 @@ class TestEnvironments(unittest.TestCase): self.assertEqual(rwd1, rwd2, f"Rewards [{i}] {rwd1} and {rwd2} do not match.") self.assertEqual(done1, done2, f"Dones [{i}] {done1} and {done2} do not match.") + def test_environment_functionality_meta(self): + """Tests that environments run without errors using random actions.""" + for id in alr_envs.ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS['DetPMP']: + with self.subTest(msg=id): + self._run_env(id) + if __name__ == '__main__': unittest.main()