fixed metaworld registration bug

This commit is contained in:
ottofabian 2021-08-25 11:43:32 +02:00
parent 5c70451018
commit 031dba541a
7 changed files with 129 additions and 9 deletions

View File

@ -236,7 +236,7 @@ for v in versions:
"mp_kwargs": { "mp_kwargs": {
"num_dof": 2 if "long" not in v.lower() else 5, "num_dof": 2 if "long" not in v.lower() else 5,
"num_basis": 5, "num_basis": 5,
"duration": 2, "duration": 20,
"width": 0.025, "width": 0.025,
"policy_type": "velocity", "policy_type": "velocity",
"weights_scale": 0.2, "weights_scale": 0.2,
@ -492,7 +492,7 @@ register(
) )
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("dmc_reacher-hard_detpmp-v0") ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("dmc_reacher-hard_detpmp-v0")
dmc_cartpole_tasks = ["balance", "balance_sparse", "swingup", "swingup_sparse", "two_poles", "three_poles"] dmc_cartpole_tasks = ["balance", "balance_sparse", "swingup", "swingup_sparse"]
for task in dmc_cartpole_tasks: for task in dmc_cartpole_tasks:
env_id = f'dmc_cartpole-{task}_dmp-v0' env_id = f'dmc_cartpole-{task}_dmp-v0'
@ -552,6 +552,120 @@ for task in dmc_cartpole_tasks:
) )
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(env_id) ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(env_id)
env_id = f'dmc_cartpole-two_poles_dmp-v0'
register(
id=env_id,
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
# max_episode_steps=1,
kwargs={
"name": f"cartpole-two_poles",
# "time_limit": 1,
"camera_id": 0,
"episode_length": 1000,
"wrappers": [dmc.suite.cartpole.TwoPolesMPWrapper],
"mp_kwargs": {
"num_dof": 1,
"num_basis": 5,
"duration": 10,
"learn_goal": True,
"alpha_phase": 2,
"bandwidth_factor": 2,
"policy_type": "motor",
"weights_scale": 50,
"goal_scale": 0.1,
"policy_kwargs": {
"p_gains": 10,
"d_gains": 10
}
}
}
)
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(env_id)
env_id = f'dmc_cartpole-two_poles_detpmp-v0'
register(
id=env_id,
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
kwargs={
"name": f"cartpole-two_poles",
# "time_limit": 1,
"camera_id": 0,
"episode_length": 1000,
"wrappers": [dmc.suite.cartpole.TwoPolesMPWrapper],
"mp_kwargs": {
"num_dof": 1,
"num_basis": 5,
"duration": 10,
"width": 0.025,
"policy_type": "motor",
"weights_scale": 0.2,
"zero_start": True,
"policy_kwargs": {
"p_gains": 10,
"d_gains": 10
}
}
}
)
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(env_id)
env_id = f'dmc_cartpole-three_poles_dmp-v0'
register(
id=env_id,
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
# max_episode_steps=1,
kwargs={
"name": f"cartpole-three_poles",
# "time_limit": 1,
"camera_id": 0,
"episode_length": 1000,
"wrappers": [dmc.suite.cartpole.ThreePolesMPWrapper],
"mp_kwargs": {
"num_dof": 1,
"num_basis": 5,
"duration": 10,
"learn_goal": True,
"alpha_phase": 2,
"bandwidth_factor": 2,
"policy_type": "motor",
"weights_scale": 50,
"goal_scale": 0.1,
"policy_kwargs": {
"p_gains": 10,
"d_gains": 10
}
}
}
)
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(env_id)
env_id = f'dmc_cartpole-three_poles_detpmp-v0'
register(
id=env_id,
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
kwargs={
"name": f"cartpole-three_poles",
# "time_limit": 1,
"camera_id": 0,
"episode_length": 1000,
"wrappers": [dmc.suite.cartpole.ThreePolesMPWrapper],
"mp_kwargs": {
"num_dof": 1,
"num_basis": 5,
"duration": 10,
"width": 0.025,
"policy_type": "motor",
"weights_scale": 0.2,
"zero_start": True,
"policy_kwargs": {
"p_gains": 10,
"d_gains": 10
}
}
}
)
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(env_id)
### Manipulation ### Manipulation
register( register(
@ -758,7 +872,7 @@ for task in object_change_envs:
id=env_id, id=env_id,
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
kwargs={ kwargs={
"name": env_id, "name": task,
"wrappers": [meta.object_change.MPWrapper], "wrappers": [meta.object_change.MPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 4, "num_dof": 4,
@ -792,7 +906,7 @@ for task in goal_and_object_change_envs:
id=env_id, id=env_id,
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
kwargs={ kwargs={
"name": env_id, "name": task,
"wrappers": [meta.goal_and_object_change.MPWrapper], "wrappers": [meta.goal_and_object_change.MPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 4, "num_dof": 4,
@ -816,7 +930,7 @@ for task in goal_and_endeffector_change_envs:
id=env_id, id=env_id,
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
kwargs={ kwargs={
"name": env_id, "name": task,
"wrappers": [meta.goal_and_endeffector_change.MPWrapper], "wrappers": [meta.goal_and_endeffector_change.MPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 4, "num_dof": 4,

View File

@ -1,4 +1,3 @@
# from alr_envs.dmc import manipulation, suite
from alr_envs.dmc.suite import ball_in_cup from alr_envs.dmc.suite import ball_in_cup
from alr_envs.dmc.suite import reacher from alr_envs.dmc.suite import reacher
from alr_envs.dmc.suite import cartpole from alr_envs.dmc.suite import cartpole

View File

@ -0,0 +1 @@
from alr_envs.dmc.manipulation import reach

View File

@ -0,0 +1 @@
from alr_envs.dmc.suite import cartpole, ball_in_cup, reacher

View File

@ -1,5 +1,4 @@
import logging from typing import Iterable, Type, Union
from typing import Iterable, List, Type, Union
import gym import gym
import numpy as np import numpy as np

View File

@ -99,6 +99,12 @@ class TestEnvironments(unittest.TestCase):
self.assertEqual(rwd1, rwd2, f"Rewards [{i}] {rwd1} and {rwd2} do not match.") self.assertEqual(rwd1, rwd2, f"Rewards [{i}] {rwd1} and {rwd2} do not match.")
self.assertEqual(done1, done2, f"Dones [{i}] {done1} and {done2} do not match.") self.assertEqual(done1, done2, f"Dones [{i}] {done1} and {done2} do not match.")
def test_environment_functionality_meta(self):
"""Tests that environments runs without errors using random actions."""
for id in alr_envs.ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS['DetPMP']:
with self.subTest(msg=id):
self._run_env(id)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()