fixed metaworld registration bug
This commit is contained in:
parent
5c70451018
commit
031dba541a
@ -236,7 +236,7 @@ for v in versions:
|
|||||||
"mp_kwargs": {
|
"mp_kwargs": {
|
||||||
"num_dof": 2 if "long" not in v.lower() else 5,
|
"num_dof": 2 if "long" not in v.lower() else 5,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 2,
|
"duration": 20,
|
||||||
"width": 0.025,
|
"width": 0.025,
|
||||||
"policy_type": "velocity",
|
"policy_type": "velocity",
|
||||||
"weights_scale": 0.2,
|
"weights_scale": 0.2,
|
||||||
@ -492,7 +492,7 @@ register(
|
|||||||
)
|
)
|
||||||
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("dmc_reacher-hard_detpmp-v0")
|
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("dmc_reacher-hard_detpmp-v0")
|
||||||
|
|
||||||
dmc_cartpole_tasks = ["balance", "balance_sparse", "swingup", "swingup_sparse", "two_poles", "three_poles"]
|
dmc_cartpole_tasks = ["balance", "balance_sparse", "swingup", "swingup_sparse"]
|
||||||
|
|
||||||
for task in dmc_cartpole_tasks:
|
for task in dmc_cartpole_tasks:
|
||||||
env_id = f'dmc_cartpole-{task}_dmp-v0'
|
env_id = f'dmc_cartpole-{task}_dmp-v0'
|
||||||
@ -552,6 +552,120 @@ for task in dmc_cartpole_tasks:
|
|||||||
)
|
)
|
||||||
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(env_id)
|
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(env_id)
|
||||||
|
|
||||||
|
env_id = f'dmc_cartpole-two_poles_dmp-v0'
|
||||||
|
register(
|
||||||
|
id=env_id,
|
||||||
|
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
|
||||||
|
# max_episode_steps=1,
|
||||||
|
kwargs={
|
||||||
|
"name": f"cartpole-two_poles",
|
||||||
|
# "time_limit": 1,
|
||||||
|
"camera_id": 0,
|
||||||
|
"episode_length": 1000,
|
||||||
|
"wrappers": [dmc.suite.cartpole.TwoPolesMPWrapper],
|
||||||
|
"mp_kwargs": {
|
||||||
|
"num_dof": 1,
|
||||||
|
"num_basis": 5,
|
||||||
|
"duration": 10,
|
||||||
|
"learn_goal": True,
|
||||||
|
"alpha_phase": 2,
|
||||||
|
"bandwidth_factor": 2,
|
||||||
|
"policy_type": "motor",
|
||||||
|
"weights_scale": 50,
|
||||||
|
"goal_scale": 0.1,
|
||||||
|
"policy_kwargs": {
|
||||||
|
"p_gains": 10,
|
||||||
|
"d_gains": 10
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(env_id)
|
||||||
|
|
||||||
|
env_id = f'dmc_cartpole-two_poles_detpmp-v0'
|
||||||
|
register(
|
||||||
|
id=env_id,
|
||||||
|
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||||
|
kwargs={
|
||||||
|
"name": f"cartpole-two_poles",
|
||||||
|
# "time_limit": 1,
|
||||||
|
"camera_id": 0,
|
||||||
|
"episode_length": 1000,
|
||||||
|
"wrappers": [dmc.suite.cartpole.TwoPolesMPWrapper],
|
||||||
|
"mp_kwargs": {
|
||||||
|
"num_dof": 1,
|
||||||
|
"num_basis": 5,
|
||||||
|
"duration": 10,
|
||||||
|
"width": 0.025,
|
||||||
|
"policy_type": "motor",
|
||||||
|
"weights_scale": 0.2,
|
||||||
|
"zero_start": True,
|
||||||
|
"policy_kwargs": {
|
||||||
|
"p_gains": 10,
|
||||||
|
"d_gains": 10
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(env_id)
|
||||||
|
|
||||||
|
env_id = f'dmc_cartpole-three_poles_dmp-v0'
|
||||||
|
register(
|
||||||
|
id=env_id,
|
||||||
|
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
|
||||||
|
# max_episode_steps=1,
|
||||||
|
kwargs={
|
||||||
|
"name": f"cartpole-three_poles",
|
||||||
|
# "time_limit": 1,
|
||||||
|
"camera_id": 0,
|
||||||
|
"episode_length": 1000,
|
||||||
|
"wrappers": [dmc.suite.cartpole.ThreePolesMPWrapper],
|
||||||
|
"mp_kwargs": {
|
||||||
|
"num_dof": 1,
|
||||||
|
"num_basis": 5,
|
||||||
|
"duration": 10,
|
||||||
|
"learn_goal": True,
|
||||||
|
"alpha_phase": 2,
|
||||||
|
"bandwidth_factor": 2,
|
||||||
|
"policy_type": "motor",
|
||||||
|
"weights_scale": 50,
|
||||||
|
"goal_scale": 0.1,
|
||||||
|
"policy_kwargs": {
|
||||||
|
"p_gains": 10,
|
||||||
|
"d_gains": 10
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(env_id)
|
||||||
|
|
||||||
|
env_id = f'dmc_cartpole-three_poles_detpmp-v0'
|
||||||
|
register(
|
||||||
|
id=env_id,
|
||||||
|
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||||
|
kwargs={
|
||||||
|
"name": f"cartpole-three_poles",
|
||||||
|
# "time_limit": 1,
|
||||||
|
"camera_id": 0,
|
||||||
|
"episode_length": 1000,
|
||||||
|
"wrappers": [dmc.suite.cartpole.ThreePolesMPWrapper],
|
||||||
|
"mp_kwargs": {
|
||||||
|
"num_dof": 1,
|
||||||
|
"num_basis": 5,
|
||||||
|
"duration": 10,
|
||||||
|
"width": 0.025,
|
||||||
|
"policy_type": "motor",
|
||||||
|
"weights_scale": 0.2,
|
||||||
|
"zero_start": True,
|
||||||
|
"policy_kwargs": {
|
||||||
|
"p_gains": 10,
|
||||||
|
"d_gains": 10
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(env_id)
|
||||||
|
|
||||||
### Manipulation
|
### Manipulation
|
||||||
|
|
||||||
register(
|
register(
|
||||||
@ -758,7 +872,7 @@ for task in object_change_envs:
|
|||||||
id=env_id,
|
id=env_id,
|
||||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||||
kwargs={
|
kwargs={
|
||||||
"name": env_id,
|
"name": task,
|
||||||
"wrappers": [meta.object_change.MPWrapper],
|
"wrappers": [meta.object_change.MPWrapper],
|
||||||
"mp_kwargs": {
|
"mp_kwargs": {
|
||||||
"num_dof": 4,
|
"num_dof": 4,
|
||||||
@ -792,7 +906,7 @@ for task in goal_and_object_change_envs:
|
|||||||
id=env_id,
|
id=env_id,
|
||||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||||
kwargs={
|
kwargs={
|
||||||
"name": env_id,
|
"name": task,
|
||||||
"wrappers": [meta.goal_and_object_change.MPWrapper],
|
"wrappers": [meta.goal_and_object_change.MPWrapper],
|
||||||
"mp_kwargs": {
|
"mp_kwargs": {
|
||||||
"num_dof": 4,
|
"num_dof": 4,
|
||||||
@ -816,7 +930,7 @@ for task in goal_and_endeffector_change_envs:
|
|||||||
id=env_id,
|
id=env_id,
|
||||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||||
kwargs={
|
kwargs={
|
||||||
"name": env_id,
|
"name": task,
|
||||||
"wrappers": [meta.goal_and_endeffector_change.MPWrapper],
|
"wrappers": [meta.goal_and_endeffector_change.MPWrapper],
|
||||||
"mp_kwargs": {
|
"mp_kwargs": {
|
||||||
"num_dof": 4,
|
"num_dof": 4,
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
# from alr_envs.dmc import manipulation, suite
|
|
||||||
from alr_envs.dmc.suite import ball_in_cup
|
from alr_envs.dmc.suite import ball_in_cup
|
||||||
from alr_envs.dmc.suite import reacher
|
from alr_envs.dmc.suite import reacher
|
||||||
from alr_envs.dmc.suite import cartpole
|
from alr_envs.dmc.suite import cartpole
|
||||||
|
@ -0,0 +1 @@
|
|||||||
|
from alr_envs.dmc.manipulation import reach
|
@ -0,0 +1 @@
|
|||||||
|
from alr_envs.dmc.suite import cartpole, ball_in_cup, reacher
|
@ -1,5 +1,4 @@
|
|||||||
import logging
|
from typing import Iterable, Type, Union
|
||||||
from typing import Iterable, List, Type, Union
|
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -99,6 +99,12 @@ class TestEnvironments(unittest.TestCase):
|
|||||||
self.assertEqual(rwd1, rwd2, f"Rewards [{i}] {rwd1} and {rwd2} do not match.")
|
self.assertEqual(rwd1, rwd2, f"Rewards [{i}] {rwd1} and {rwd2} do not match.")
|
||||||
self.assertEqual(done1, done2, f"Dones [{i}] {done1} and {done2} do not match.")
|
self.assertEqual(done1, done2, f"Dones [{i}] {done1} and {done2} do not match.")
|
||||||
|
|
||||||
|
def test_environment_functionality_meta(self):
|
||||||
|
"""Tests that environments runs without errors using random actions."""
|
||||||
|
for id in alr_envs.ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS['DetPMP']:
|
||||||
|
with self.subTest(msg=id):
|
||||||
|
self._run_env(id)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user