updated envs and registering
This commit is contained in:
parent
0a1e55d97b
commit
c8742e2934
@ -1,12 +1,16 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from gym.envs.registration import register
|
from gym.envs.registration import register
|
||||||
|
|
||||||
|
from alr_envs.classic_control.hole_reacher.hole_reacher_mp_wrapper import HoleReacherMPWrapper
|
||||||
|
from alr_envs.classic_control.simple_reacher.simple_reacher_mp_wrapper import SimpleReacherMPWrapper
|
||||||
|
from alr_envs.classic_control.viapoint_reacher.viapoint_reacher_mp_wrapper import ViaPointReacherMPWrapper
|
||||||
|
from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup_mp_wrapper import BallInACupMPWrapper
|
||||||
|
from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup_positional_wrapper import BallInACupPositionalWrapper
|
||||||
from alr_envs.stochastic_search.functions.f_rosenbrock import Rosenbrock
|
from alr_envs.stochastic_search.functions.f_rosenbrock import Rosenbrock
|
||||||
|
|
||||||
# from alr_envs.utils.mps.dmp_wrapper import DmpWrapper
|
|
||||||
|
|
||||||
# Mujoco
|
# Mujoco
|
||||||
|
|
||||||
|
## Reacher
|
||||||
register(
|
register(
|
||||||
id='ALRReacher-v0',
|
id='ALRReacher-v0',
|
||||||
entry_point='alr_envs.mujoco:ALRReacherEnv',
|
entry_point='alr_envs.mujoco:ALRReacherEnv',
|
||||||
@ -177,7 +181,7 @@ register(
|
|||||||
|
|
||||||
register(
|
register(
|
||||||
id='ViaPointReacher-v0',
|
id='ViaPointReacher-v0',
|
||||||
entry_point='alr_envs.classic_control.viapoint_reacher:ViaPointReacher',
|
entry_point='alr_envs.classic_control:ViaPointReacher',
|
||||||
max_episode_steps=200,
|
max_episode_steps=200,
|
||||||
kwargs={
|
kwargs={
|
||||||
"n_links": 5,
|
"n_links": 5,
|
||||||
@ -189,7 +193,7 @@ register(
|
|||||||
## Hole Reacher
|
## Hole Reacher
|
||||||
register(
|
register(
|
||||||
id='HoleReacher-v0',
|
id='HoleReacher-v0',
|
||||||
entry_point='alr_envs.classic_control.hole_reacher:HoleReacherEnv',
|
entry_point='alr_envs.classic_control:HoleReacherEnv',
|
||||||
max_episode_steps=200,
|
max_episode_steps=200,
|
||||||
kwargs={
|
kwargs={
|
||||||
"n_links": 5,
|
"n_links": 5,
|
||||||
@ -205,7 +209,7 @@ register(
|
|||||||
|
|
||||||
register(
|
register(
|
||||||
id='HoleReacher-v1',
|
id='HoleReacher-v1',
|
||||||
entry_point='alr_envs.classic_control.hole_reacher:HoleReacherEnv',
|
entry_point='alr_envs.classic_control:HoleReacherEnv',
|
||||||
max_episode_steps=200,
|
max_episode_steps=200,
|
||||||
kwargs={
|
kwargs={
|
||||||
"n_links": 5,
|
"n_links": 5,
|
||||||
@ -221,7 +225,7 @@ register(
|
|||||||
|
|
||||||
register(
|
register(
|
||||||
id='HoleReacher-v2',
|
id='HoleReacher-v2',
|
||||||
entry_point='alr_envs.classic_control.hole_reacher:HoleReacherEnv',
|
entry_point='alr_envs.classic_control:HoleReacherEnv',
|
||||||
max_episode_steps=200,
|
max_episode_steps=200,
|
||||||
kwargs={
|
kwargs={
|
||||||
"n_links": 5,
|
"n_links": 5,
|
||||||
@ -236,39 +240,46 @@ register(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# MP environments
|
# MP environments
|
||||||
|
|
||||||
## Simple Reacher
|
## Simple Reacher
|
||||||
versions = ["SimpleReacher-v0", "SimpleReacher-v1", "LongSimpleReacher-v0", "LongSimpleReacher-v1"]
|
versions = ["SimpleReacher-v0", "SimpleReacher-v1", "LongSimpleReacher-v0", "LongSimpleReacher-v1"]
|
||||||
for v in versions:
|
for v in versions:
|
||||||
name = v.split("-")
|
name = v.split("-")
|
||||||
register(
|
register(
|
||||||
id=f'{name[0]}DMP-{name[1]}',
|
id=f'{name[0]}DMP-{name[1]}',
|
||||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
|
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
|
||||||
# max_episode_steps=1,
|
# max_episode_steps=1,
|
||||||
kwargs={
|
kwargs={
|
||||||
"name": f"alr_envs:{v}",
|
"name": f"alr_envs:{v}",
|
||||||
"num_dof": 2 if "long" not in v.lower() else 5,
|
"wrappers": [SimpleReacherMPWrapper],
|
||||||
"num_basis": 5,
|
"mp_kwargs": {
|
||||||
"duration": 2,
|
"num_dof": 2 if "long" not in v.lower() else 5,
|
||||||
"alpha_phase": 2,
|
"num_basis": 5,
|
||||||
"learn_goal": True,
|
"duration": 2,
|
||||||
"policy_type": "velocity",
|
"alpha_phase": 2,
|
||||||
"weights_scale": 50,
|
"learn_goal": True,
|
||||||
|
"policy_type": "velocity",
|
||||||
|
"weights_scale": 50,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='ViaPointReacherDMP-v0',
|
id='ViaPointReacherDMP-v0',
|
||||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
|
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
|
||||||
# max_episode_steps=1,
|
# max_episode_steps=1,
|
||||||
kwargs={
|
kwargs={
|
||||||
"name": "alr_envs:ViaPointReacher-v0",
|
"name": "alr_envs:ViaPointReacher-v0",
|
||||||
"num_dof": 5,
|
"wrappers": [ViaPointReacherMPWrapper],
|
||||||
"num_basis": 5,
|
"mp_kwargs": {
|
||||||
"duration": 2,
|
"num_dof": 5,
|
||||||
"alpha_phase": 2,
|
"num_basis": 5,
|
||||||
"learn_goal": False,
|
"duration": 2,
|
||||||
"policy_type": "velocity",
|
"learn_goal": True,
|
||||||
"weights_scale": 50,
|
"alpha_phase": 2,
|
||||||
|
"policy_type": "velocity",
|
||||||
|
"weights_scale": 50,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -277,52 +288,61 @@ versions = ["v0", "v1", "v2"]
|
|||||||
for v in versions:
|
for v in versions:
|
||||||
register(
|
register(
|
||||||
id=f'HoleReacherDMP-{v}',
|
id=f'HoleReacherDMP-{v}',
|
||||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
|
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
|
||||||
# max_episode_steps=1,
|
# max_episode_steps=1,
|
||||||
kwargs={
|
kwargs={
|
||||||
"name": f"alr_envs:HoleReacher-{v}",
|
"name": f"alr_envs:HoleReacher-{v}",
|
||||||
"num_dof": 5,
|
"wrappers": [HoleReacherMPWrapper],
|
||||||
"num_basis": 5,
|
"mp_kwargs": {
|
||||||
"duration": 2,
|
"num_dof": 5,
|
||||||
"learn_goal": True,
|
"num_basis": 5,
|
||||||
"alpha_phase": 2,
|
"duration": 2,
|
||||||
"bandwidth_factor": 2,
|
"learn_goal": True,
|
||||||
"policy_type": "velocity",
|
"alpha_phase": 2,
|
||||||
"weights_scale": 50,
|
"bandwidth_factor": 2,
|
||||||
"goal_scale": 0.1
|
"policy_type": "velocity",
|
||||||
|
"weights_scale": 50,
|
||||||
|
"goal_scale": 0.1
|
||||||
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id=f'HoleReacherDetPMP-{v}',
|
id=f'HoleReacherDetPMP-{v}',
|
||||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env',
|
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||||
kwargs={
|
kwargs={
|
||||||
"name": f"alr_envs:HoleReacher-{v}",
|
"name": f"alr_envs:HoleReacher-{v}",
|
||||||
"num_dof": 5,
|
"wrappers": [HoleReacherMPWrapper],
|
||||||
"num_basis": 5,
|
"mp_kwargs": {
|
||||||
"duration": 2,
|
"num_dof": 5,
|
||||||
"width": 0.025,
|
"num_basis": 5,
|
||||||
"policy_type": "velocity",
|
"duration": 2,
|
||||||
"weights_scale": 0.2,
|
"width": 0.025,
|
||||||
"zero_start": True
|
"policy_type": "velocity",
|
||||||
|
"weights_scale": 0.2,
|
||||||
|
"zero_start": True
|
||||||
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: properly add final_pos
|
# TODO: properly add final_pos
|
||||||
register(
|
register(
|
||||||
id='HoleReacherFixedGoalDMP-v0',
|
id='HoleReacherFixedGoalDMP-v0',
|
||||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
|
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
|
||||||
# max_episode_steps=1,
|
# max_episode_steps=1,
|
||||||
kwargs={
|
kwargs={
|
||||||
"name": "alr_envs:HoleReacher-v0",
|
"name": "alr_envs:HoleReacher-v0",
|
||||||
"num_dof": 5,
|
"wrappers": [HoleReacherMPWrapper],
|
||||||
"num_basis": 5,
|
"mp_kwargs": {
|
||||||
"duration": 2,
|
"num_dof": 5,
|
||||||
"learn_goal": False,
|
"num_basis": 5,
|
||||||
"alpha_phase": 2,
|
"duration": 2,
|
||||||
"policy_type": "velocity",
|
"learn_goal": False,
|
||||||
"weights_scale": 50,
|
"alpha_phase": 2,
|
||||||
"goal_scale": 0.1
|
"policy_type": "velocity",
|
||||||
|
"weights_scale": 50,
|
||||||
|
"goal_scale": 0.1
|
||||||
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -330,81 +350,101 @@ register(
|
|||||||
|
|
||||||
register(
|
register(
|
||||||
id='ALRBallInACupSimpleDMP-v0',
|
id='ALRBallInACupSimpleDMP-v0',
|
||||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
|
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
|
||||||
kwargs={
|
kwargs={
|
||||||
"name": "alr_envs:ALRBallInACupSimple-v0",
|
"name": "alr_envs:ALRBallInACupSimple-v0",
|
||||||
"num_dof": 3,
|
"wrappers": [BallInACupMPWrapper, BallInACupPositionalWrapper],
|
||||||
"num_basis": 5,
|
"mp_kwargs": {
|
||||||
"duration": 3.5,
|
"num_dof": 3,
|
||||||
"post_traj_time": 4.5,
|
"num_basis": 5,
|
||||||
"learn_goal": False,
|
"duration": 3.5,
|
||||||
"alpha_phase": 3,
|
"post_traj_time": 4.5,
|
||||||
"bandwidth_factor": 2.5,
|
"learn_goal": False,
|
||||||
"policy_type": "motor",
|
"alpha_phase": 3,
|
||||||
"weights_scale": 100,
|
"bandwidth_factor": 2.5,
|
||||||
"return_to_start": True,
|
"policy_type": "motor",
|
||||||
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
|
"weights_scale": 100,
|
||||||
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
"return_to_start": True,
|
||||||
|
"policy_kwargs": {
|
||||||
|
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
|
||||||
|
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='ALRBallInACupDMP-v0',
|
id='ALRBallInACupDMP-v0',
|
||||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
|
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
|
||||||
kwargs={
|
kwargs={
|
||||||
"name": "alr_envs:ALRBallInACup-v0",
|
"name": "alr_envs:ALRBallInACup-v0",
|
||||||
"num_dof": 7,
|
"wrappers": [BallInACupMPWrapper, BallInACupPositionalWrapper],
|
||||||
"num_basis": 5,
|
"mp_kwargs": {
|
||||||
"duration": 3.5,
|
"num_dof": 7,
|
||||||
"post_traj_time": 4.5,
|
"num_basis": 5,
|
||||||
"learn_goal": False,
|
"duration": 3.5,
|
||||||
"alpha_phase": 3,
|
"post_traj_time": 4.5,
|
||||||
"bandwidth_factor": 2.5,
|
"learn_goal": False,
|
||||||
"policy_type": "motor",
|
"alpha_phase": 3,
|
||||||
"weights_scale": 100,
|
"bandwidth_factor": 2.5,
|
||||||
"return_to_start": True,
|
"policy_type": "motor",
|
||||||
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
|
"weights_scale": 100,
|
||||||
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
"return_to_start": True,
|
||||||
|
"policy_kwargs": {
|
||||||
|
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
|
||||||
|
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='ALRBallInACupSimpleDetPMP-v0',
|
id='ALRBallInACupSimpleDetPMP-v0',
|
||||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env',
|
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||||
kwargs={
|
kwargs={
|
||||||
"name": "alr_envs:ALRBallInACupSimple-v0",
|
"name": "alr_envs:ALRBallInACupSimple-v0",
|
||||||
"num_dof": 3,
|
"wrappers": [BallInACupMPWrapper, BallInACupPositionalWrapper],
|
||||||
"num_basis": 5,
|
"mp_kwargs": {
|
||||||
"duration": 3.5,
|
"num_dof": 3,
|
||||||
"post_traj_time": 4.5,
|
"num_basis": 5,
|
||||||
"width": 0.0035,
|
"duration": 3.5,
|
||||||
# "off": -0.05,
|
"post_traj_time": 4.5,
|
||||||
"policy_type": "motor",
|
"width": 0.0035,
|
||||||
"weights_scale": 0.2,
|
# "off": -0.05,
|
||||||
"zero_start": True,
|
"policy_type": "motor",
|
||||||
"zero_goal": True,
|
"weights_scale": 0.2,
|
||||||
"p_gains": np.array([4./3., 2.4, 2.5, 5./3., 2., 2., 1.25]),
|
"zero_start": True,
|
||||||
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
"zero_goal": True,
|
||||||
|
"policy_kwargs": {
|
||||||
|
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
|
||||||
|
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='ALRBallInACupPDSimpleDetPMP-v0',
|
id='ALRBallInACupPDSimpleDetPMP-v0',
|
||||||
entry_point='alr_envs.mujoco.ball_in_a_cup.biac_pd:make_detpmp_env',
|
entry_point='alr_envs.mujoco.ball_in_a_cup.biac_pd:make_detpmp_env_helper',
|
||||||
kwargs={
|
kwargs={
|
||||||
"name": "alr_envs:ALRBallInACupPDSimple-v0",
|
"name": "alr_envs:ALRBallInACupPDSimple-v0",
|
||||||
"num_dof": 3,
|
"wrappers": [BallInACupMPWrapper, BallInACupPositionalWrapper],
|
||||||
"num_basis": 5,
|
"mp_kwargs": {
|
||||||
"duration": 3.5,
|
"num_dof": 3,
|
||||||
"post_traj_time": 4.5,
|
"num_basis": 5,
|
||||||
"width": 0.0035,
|
"duration": 3.5,
|
||||||
# "off": -0.05,
|
"post_traj_time": 4.5,
|
||||||
"policy_type": "motor",
|
"width": 0.0035,
|
||||||
"weights_scale": 0.2,
|
# "off": -0.05,
|
||||||
"zero_start": True,
|
"policy_type": "motor",
|
||||||
"zero_goal": True,
|
"weights_scale": 0.2,
|
||||||
"p_gains": np.array([4./3., 2.4, 2.5, 5./3., 2., 2., 1.25]),
|
"zero_start": True,
|
||||||
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
"zero_goal": True,
|
||||||
|
"policy_kwargs": {
|
||||||
|
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
|
||||||
|
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -430,20 +470,26 @@ register(
|
|||||||
|
|
||||||
register(
|
register(
|
||||||
id='ALRBallInACupDetPMP-v0',
|
id='ALRBallInACupDetPMP-v0',
|
||||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env',
|
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||||
kwargs={
|
kwargs={
|
||||||
"name": "alr_envs:ALRBallInACupSimple-v0",
|
"name": "alr_envs:ALRBallInACupSimple-v0",
|
||||||
"num_dof": 7,
|
"wrappers": [BallInACupMPWrapper, BallInACupPositionalWrapper],
|
||||||
"num_basis": 5,
|
"mp_kwargs": {
|
||||||
"duration": 3.5,
|
"num_dof": 7,
|
||||||
"post_traj_time": 4.5,
|
"num_basis": 5,
|
||||||
"width": 0.0035,
|
"duration": 3.5,
|
||||||
"policy_type": "motor",
|
"post_traj_time": 4.5,
|
||||||
"weights_scale": 0.2,
|
"width": 0.0035,
|
||||||
"zero_start": True,
|
"policy_type": "motor",
|
||||||
"zero_goal": True,
|
"weights_scale": 0.2,
|
||||||
"p_gains": np.array([4./3., 2.4, 2.5, 5./3., 2., 2., 1.25]),
|
"zero_start": True,
|
||||||
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
"zero_goal": True,
|
||||||
|
"policy_kwargs": {
|
||||||
|
|
||||||
|
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
|
||||||
|
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -452,18 +498,23 @@ register(
|
|||||||
entry_point='alr_envs.utils.make_env_helpers:make_contextual_env',
|
entry_point='alr_envs.utils.make_env_helpers:make_contextual_env',
|
||||||
kwargs={
|
kwargs={
|
||||||
"name": "alr_envs:ALRBallInACupGoal-v0",
|
"name": "alr_envs:ALRBallInACupGoal-v0",
|
||||||
"num_dof": 7,
|
"wrappers": [BallInACupMPWrapper, BallInACupPositionalWrapper],
|
||||||
"num_basis": 5,
|
"mp_kwargs": {
|
||||||
"duration": 3.5,
|
"num_dof": 7,
|
||||||
"post_traj_time": 4.5,
|
"num_basis": 5,
|
||||||
"learn_goal": True,
|
"duration": 3.5,
|
||||||
"alpha_phase": 3,
|
"post_traj_time": 4.5,
|
||||||
"bandwidth_factor": 2.5,
|
"learn_goal": True,
|
||||||
"policy_type": "motor",
|
"alpha_phase": 3,
|
||||||
"weights_scale": 50,
|
"bandwidth_factor": 2.5,
|
||||||
"goal_scale": 0.1,
|
"policy_type": "motor",
|
||||||
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
|
"weights_scale": 50,
|
||||||
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
"goal_scale": 0.1,
|
||||||
|
"policy_kwargs": {
|
||||||
|
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
|
||||||
|
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,3 +1,3 @@
|
|||||||
|
from alr_envs.classic_control.hole_reacher.hole_reacher import HoleReacherEnv
|
||||||
from alr_envs.classic_control.viapoint_reacher.viapoint_reacher import ViaPointReacher
|
from alr_envs.classic_control.viapoint_reacher.viapoint_reacher import ViaPointReacher
|
||||||
from alr_envs.classic_control.simple_reacher.simple_reacher import SimpleReacherEnv
|
from alr_envs.classic_control.simple_reacher.simple_reacher import SimpleReacherEnv
|
||||||
from alr_envs.classic_control.hole_reacher.hole_reacher import HoleReacherEnv
|
|
||||||
|
@ -21,14 +21,14 @@ class HoleReacherEnv(gym.Env):
|
|||||||
self.random_start = random_start
|
self.random_start = random_start
|
||||||
|
|
||||||
# provided initial parameters
|
# provided initial parameters
|
||||||
self.hole_x = hole_x # x-position of center of hole
|
self.initial_x = hole_x # x-position of center of hole
|
||||||
self.hole_width = hole_width # width of hole
|
self.initial_width = hole_width # width of hole
|
||||||
self.hole_depth = hole_depth # depth of hole
|
self.initial_depth = hole_depth # depth of hole
|
||||||
|
|
||||||
# temp container for current env state
|
# temp container for current env state
|
||||||
self._tmp_hole_x = None
|
self._tmp_x = None
|
||||||
self._tmp_hole_width = None
|
self._tmp_width = None
|
||||||
self._tmp_hole_depth = None
|
self._tmp_depth = None
|
||||||
self._goal = None # x-y coordinates for reaching the center at the bottom of the hole
|
self._goal = None # x-y coordinates for reaching the center at the bottom of the hole
|
||||||
|
|
||||||
# collision
|
# collision
|
||||||
@ -69,6 +69,10 @@ class HoleReacherEnv(gym.Env):
|
|||||||
def dt(self) -> Union[float, int]:
|
def dt(self) -> Union[float, int]:
|
||||||
return self._dt
|
return self._dt
|
||||||
|
|
||||||
|
@property
|
||||||
|
def start_pos(self):
|
||||||
|
return self._start_pos
|
||||||
|
|
||||||
def step(self, action: np.ndarray):
|
def step(self, action: np.ndarray):
|
||||||
"""
|
"""
|
||||||
A single step with an action in joint velocity space
|
A single step with an action in joint velocity space
|
||||||
@ -110,13 +114,13 @@ class HoleReacherEnv(gym.Env):
|
|||||||
return self._get_obs().copy()
|
return self._get_obs().copy()
|
||||||
|
|
||||||
def _generate_hole(self):
|
def _generate_hole(self):
|
||||||
self._tmp_hole_x = self.np_random.uniform(1, 3.5, 1) if self.hole_x is None else np.copy(self.hole_x)
|
self._tmp_x = self.np_random.uniform(1, 3.5, 1) if self.initial_x is None else np.copy(self.initial_x)
|
||||||
self._tmp_hole_width = self.np_random.uniform(0.15, 0.5, 1) if self.hole_width is None else np.copy(
|
self._tmp_width = self.np_random.uniform(0.15, 0.5, 1) if self.initial_width is None else np.copy(
|
||||||
self.hole_width)
|
self.initial_width)
|
||||||
# TODO we do not want this right now.
|
# TODO we do not want this right now.
|
||||||
self._tmp_hole_depth = self.np_random.uniform(1, 1, 1) if self.hole_depth is None else np.copy(
|
self._tmp_depth = self.np_random.uniform(1, 1, 1) if self.initial_depth is None else np.copy(
|
||||||
self.hole_depth)
|
self.initial_depth)
|
||||||
self._goal = np.hstack([self._tmp_hole_x, -self._tmp_hole_depth])
|
self._goal = np.hstack([self._tmp_x, -self._tmp_depth])
|
||||||
|
|
||||||
def _update_joints(self):
|
def _update_joints(self):
|
||||||
"""
|
"""
|
||||||
@ -164,7 +168,7 @@ class HoleReacherEnv(gym.Env):
|
|||||||
np.cos(theta),
|
np.cos(theta),
|
||||||
np.sin(theta),
|
np.sin(theta),
|
||||||
self._angle_velocity,
|
self._angle_velocity,
|
||||||
self._tmp_hole_width,
|
self._tmp_width,
|
||||||
# self._tmp_hole_depth,
|
# self._tmp_hole_depth,
|
||||||
self.end_effector - self._goal,
|
self.end_effector - self._goal,
|
||||||
self._steps
|
self._steps
|
||||||
@ -192,7 +196,7 @@ class HoleReacherEnv(gym.Env):
|
|||||||
def _check_wall_collision(self, line_points):
|
def _check_wall_collision(self, line_points):
|
||||||
|
|
||||||
# all points that are before the hole in x
|
# all points that are before the hole in x
|
||||||
r, c = np.where(line_points[:, :, 0] < (self._tmp_hole_x - self._tmp_hole_width / 2))
|
r, c = np.where(line_points[:, :, 0] < (self._tmp_x - self._tmp_width / 2))
|
||||||
|
|
||||||
# check if any of those points are below surface
|
# check if any of those points are below surface
|
||||||
nr_line_points_below_surface_before_hole = np.sum(line_points[r, c, 1] < 0)
|
nr_line_points_below_surface_before_hole = np.sum(line_points[r, c, 1] < 0)
|
||||||
@ -201,7 +205,7 @@ class HoleReacherEnv(gym.Env):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
# all points that are after the hole in x
|
# all points that are after the hole in x
|
||||||
r, c = np.where(line_points[:, :, 0] > (self._tmp_hole_x + self._tmp_hole_width / 2))
|
r, c = np.where(line_points[:, :, 0] > (self._tmp_x + self._tmp_width / 2))
|
||||||
|
|
||||||
# check if any of those points are below surface
|
# check if any of those points are below surface
|
||||||
nr_line_points_below_surface_after_hole = np.sum(line_points[r, c, 1] < 0)
|
nr_line_points_below_surface_after_hole = np.sum(line_points[r, c, 1] < 0)
|
||||||
@ -210,11 +214,11 @@ class HoleReacherEnv(gym.Env):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
# all points that are above the hole
|
# all points that are above the hole
|
||||||
r, c = np.where((line_points[:, :, 0] > (self._tmp_hole_x - self._tmp_hole_width / 2)) & (
|
r, c = np.where((line_points[:, :, 0] > (self._tmp_x - self._tmp_width / 2)) & (
|
||||||
line_points[:, :, 0] < (self._tmp_hole_x + self._tmp_hole_width / 2)))
|
line_points[:, :, 0] < (self._tmp_x + self._tmp_width / 2)))
|
||||||
|
|
||||||
# check if any of those points are below surface
|
# check if any of those points are below surface
|
||||||
nr_line_points_below_surface_in_hole = np.sum(line_points[r, c, 1] < -self._tmp_hole_depth)
|
nr_line_points_below_surface_in_hole = np.sum(line_points[r, c, 1] < -self._tmp_depth)
|
||||||
|
|
||||||
if nr_line_points_below_surface_in_hole > 0:
|
if nr_line_points_below_surface_in_hole > 0:
|
||||||
return True
|
return True
|
||||||
@ -257,17 +261,17 @@ class HoleReacherEnv(gym.Env):
|
|||||||
def _set_patches(self):
|
def _set_patches(self):
|
||||||
if self.fig is not None:
|
if self.fig is not None:
|
||||||
self.fig.gca().patches = []
|
self.fig.gca().patches = []
|
||||||
left_block = patches.Rectangle((-self.n_links, -self._tmp_hole_depth),
|
left_block = patches.Rectangle((-self.n_links, -self._tmp_depth),
|
||||||
self.n_links + self._tmp_hole_x - self._tmp_hole_width / 2,
|
self.n_links + self._tmp_x - self._tmp_width / 2,
|
||||||
self._tmp_hole_depth,
|
self._tmp_depth,
|
||||||
fill=True, edgecolor='k', facecolor='k')
|
fill=True, edgecolor='k', facecolor='k')
|
||||||
right_block = patches.Rectangle((self._tmp_hole_x + self._tmp_hole_width / 2, -self._tmp_hole_depth),
|
right_block = patches.Rectangle((self._tmp_x + self._tmp_width / 2, -self._tmp_depth),
|
||||||
self.n_links - self._tmp_hole_x + self._tmp_hole_width / 2,
|
self.n_links - self._tmp_x + self._tmp_width / 2,
|
||||||
self._tmp_hole_depth,
|
self._tmp_depth,
|
||||||
fill=True, edgecolor='k', facecolor='k')
|
fill=True, edgecolor='k', facecolor='k')
|
||||||
hole_floor = patches.Rectangle((self._tmp_hole_x - self._tmp_hole_width / 2, -self._tmp_hole_depth),
|
hole_floor = patches.Rectangle((self._tmp_x - self._tmp_width / 2, -self._tmp_depth),
|
||||||
self._tmp_hole_width,
|
self._tmp_width,
|
||||||
1 - self._tmp_hole_depth,
|
1 - self._tmp_depth,
|
||||||
fill=True, edgecolor='k', facecolor='k')
|
fill=True, edgecolor='k', facecolor='k')
|
||||||
|
|
||||||
# Add the patch to the Axes
|
# Add the patch to the Axes
|
||||||
|
@ -12,7 +12,7 @@ class HoleReacherMPWrapper(MPEnvWrapper):
|
|||||||
[self.env.random_start] * self.env.n_links, # cos
|
[self.env.random_start] * self.env.n_links, # cos
|
||||||
[self.env.random_start] * self.env.n_links, # sin
|
[self.env.random_start] * self.env.n_links, # sin
|
||||||
[self.env.random_start] * self.env.n_links, # velocity
|
[self.env.random_start] * self.env.n_links, # velocity
|
||||||
[self.env.hole_width is None], # hole width
|
[self.env.initial_width is None], # hole width
|
||||||
# [self.env.hole_depth is None], # hole depth
|
# [self.env.hole_depth is None], # hole depth
|
||||||
[True] * 2, # x-y coordinates of target distance
|
[True] * 2, # x-y coordinates of target distance
|
||||||
[False] # env steps
|
[False] # env steps
|
||||||
@ -20,7 +20,7 @@ class HoleReacherMPWrapper(MPEnvWrapper):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def start_pos(self) -> Union[float, int, np.ndarray]:
|
def start_pos(self) -> Union[float, int, np.ndarray]:
|
||||||
return self._start_pos
|
return self.env.start_pos
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def goal_pos(self) -> Union[float, int, np.ndarray]:
|
def goal_pos(self) -> Union[float, int, np.ndarray]:
|
||||||
|
@ -22,15 +22,18 @@ class SimpleReacherEnv(gym.Env):
|
|||||||
|
|
||||||
self.random_start = random_start
|
self.random_start = random_start
|
||||||
|
|
||||||
|
# provided initial parameters
|
||||||
|
self.inital_target = target
|
||||||
|
|
||||||
|
# temp container for current env state
|
||||||
|
self._goal = None
|
||||||
|
|
||||||
self._joints = None
|
self._joints = None
|
||||||
self._joint_angles = None
|
self._joint_angles = None
|
||||||
self._angle_velocity = None
|
self._angle_velocity = None
|
||||||
self._start_pos = np.zeros(self.n_links)
|
self._start_pos = np.zeros(self.n_links)
|
||||||
self._start_vel = np.zeros(self.n_links)
|
self._start_vel = np.zeros(self.n_links)
|
||||||
|
|
||||||
self._target = target # provided target value
|
|
||||||
self._goal = None # updated goal value, does not change when target != None
|
|
||||||
|
|
||||||
self.max_torque = 1
|
self.max_torque = 1
|
||||||
self.steps_before_reward = 199
|
self.steps_before_reward = 199
|
||||||
|
|
||||||
@ -56,6 +59,10 @@ class SimpleReacherEnv(gym.Env):
|
|||||||
def dt(self) -> Union[float, int]:
|
def dt(self) -> Union[float, int]:
|
||||||
return self._dt
|
return self._dt
|
||||||
|
|
||||||
|
@property
|
||||||
|
def start_pos(self):
|
||||||
|
return self._start_pos
|
||||||
|
|
||||||
def step(self, action: np.ndarray):
|
def step(self, action: np.ndarray):
|
||||||
"""
|
"""
|
||||||
A single step with action in torque space
|
A single step with action in torque space
|
||||||
@ -129,14 +136,14 @@ class SimpleReacherEnv(gym.Env):
|
|||||||
|
|
||||||
def _generate_goal(self):
|
def _generate_goal(self):
|
||||||
|
|
||||||
if self._target is None:
|
if self.inital_target is None:
|
||||||
|
|
||||||
total_length = np.sum(self.link_lengths)
|
total_length = np.sum(self.link_lengths)
|
||||||
goal = np.array([total_length, total_length])
|
goal = np.array([total_length, total_length])
|
||||||
while np.linalg.norm(goal) >= total_length:
|
while np.linalg.norm(goal) >= total_length:
|
||||||
goal = self.np_random.uniform(low=-total_length, high=total_length, size=2)
|
goal = self.np_random.uniform(low=-total_length, high=total_length, size=2)
|
||||||
else:
|
else:
|
||||||
goal = np.copy(self._target)
|
goal = np.copy(self.inital_target)
|
||||||
|
|
||||||
self._goal = goal
|
self._goal = goal
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@ class SimpleReacherMPWrapper(MPEnvWrapper):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def start_pos(self):
|
def start_pos(self):
|
||||||
return self._start_pos
|
return self.env.start_pos
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def goal_pos(self):
|
def goal_pos(self):
|
||||||
|
@ -10,7 +10,7 @@ from alr_envs.classic_control.utils import check_self_collision
|
|||||||
|
|
||||||
class ViaPointReacher(gym.Env):
|
class ViaPointReacher(gym.Env):
|
||||||
|
|
||||||
def __init__(self, n_links, random_start: bool = True, via_target: Union[None, Iterable] = None,
|
def __init__(self, n_links, random_start: bool = False, via_target: Union[None, Iterable] = None,
|
||||||
target: Union[None, Iterable] = None, allow_self_collision=False, collision_penalty=1000):
|
target: Union[None, Iterable] = None, allow_self_collision=False, collision_penalty=1000):
|
||||||
|
|
||||||
self.n_links = n_links
|
self.n_links = n_links
|
||||||
@ -19,8 +19,8 @@ class ViaPointReacher(gym.Env):
|
|||||||
self.random_start = random_start
|
self.random_start = random_start
|
||||||
|
|
||||||
# provided initial parameters
|
# provided initial parameters
|
||||||
self.target = target # provided target value
|
self.intitial_target = target # provided target value
|
||||||
self.via_target = via_target # provided via point target value
|
self.initial_via_target = via_target # provided via point target value
|
||||||
|
|
||||||
# temp container for current env state
|
# temp container for current env state
|
||||||
self._via_point = np.ones(2)
|
self._via_point = np.ones(2)
|
||||||
@ -63,6 +63,10 @@ class ViaPointReacher(gym.Env):
|
|||||||
def dt(self):
|
def dt(self):
|
||||||
return self._dt
|
return self._dt
|
||||||
|
|
||||||
|
@property
|
||||||
|
def start_pos(self):
|
||||||
|
return self._start_pos
|
||||||
|
|
||||||
def step(self, action: np.ndarray):
|
def step(self, action: np.ndarray):
|
||||||
"""
|
"""
|
||||||
a single step with an action in joint velocity space
|
a single step with an action in joint velocity space
|
||||||
@ -107,22 +111,22 @@ class ViaPointReacher(gym.Env):
|
|||||||
total_length = np.sum(self.link_lengths)
|
total_length = np.sum(self.link_lengths)
|
||||||
|
|
||||||
# rejection sampled point in inner circle with 0.5*Radius
|
# rejection sampled point in inner circle with 0.5*Radius
|
||||||
if self.via_target is None:
|
if self.initial_via_target is None:
|
||||||
via_target = np.array([total_length, total_length])
|
via_target = np.array([total_length, total_length])
|
||||||
while np.linalg.norm(via_target) >= 0.5 * total_length:
|
while np.linalg.norm(via_target) >= 0.5 * total_length:
|
||||||
via_target = self.np_random.uniform(low=-0.5 * total_length, high=0.5 * total_length, size=2)
|
via_target = self.np_random.uniform(low=-0.5 * total_length, high=0.5 * total_length, size=2)
|
||||||
else:
|
else:
|
||||||
via_target = np.copy(self.via_target)
|
via_target = np.copy(self.initial_via_target)
|
||||||
|
|
||||||
# rejection sampled point in outer circle
|
# rejection sampled point in outer circle
|
||||||
if self.target is None:
|
if self.intitial_target is None:
|
||||||
goal = np.array([total_length, total_length])
|
goal = np.array([total_length, total_length])
|
||||||
while np.linalg.norm(goal) >= total_length or np.linalg.norm(goal) <= 0.5 * total_length:
|
while np.linalg.norm(goal) >= total_length or np.linalg.norm(goal) <= 0.5 * total_length:
|
||||||
goal = self.np_random.uniform(low=-total_length, high=total_length, size=2)
|
goal = self.np_random.uniform(low=-total_length, high=total_length, size=2)
|
||||||
else:
|
else:
|
||||||
goal = np.copy(self.target)
|
goal = np.copy(self.intitial_target)
|
||||||
|
|
||||||
self.via_target = via_target
|
self._via_point = via_target
|
||||||
self._goal = goal
|
self._goal = goal
|
||||||
|
|
||||||
def _update_joints(self):
|
def _update_joints(self):
|
||||||
|
@ -12,18 +12,19 @@ class ViaPointReacherMPWrapper(MPEnvWrapper):
|
|||||||
[self.env.random_start] * self.env.n_links, # cos
|
[self.env.random_start] * self.env.n_links, # cos
|
||||||
[self.env.random_start] * self.env.n_links, # sin
|
[self.env.random_start] * self.env.n_links, # sin
|
||||||
[self.env.random_start] * self.env.n_links, # velocity
|
[self.env.random_start] * self.env.n_links, # velocity
|
||||||
[self.env.via_target is None] * 2, # x-y coordinates of via point distance
|
[self.env.initial_via_target is None] * 2, # x-y coordinates of via point distance
|
||||||
[True] * 2, # x-y coordinates of target distance
|
[True] * 2, # x-y coordinates of target distance
|
||||||
[False] # env steps
|
[False] # env steps
|
||||||
])
|
])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def start_pos(self) -> Union[float, int, np.ndarray]:
|
def start_pos(self) -> Union[float, int, np.ndarray]:
|
||||||
return self._start_pos
|
return self.env.start_pos
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def goal_pos(self) -> Union[float, int, np.ndarray]:
|
def goal_pos(self) -> Union[float, int, np.ndarray]:
|
||||||
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
|
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
|
||||||
|
|
||||||
|
@property
|
||||||
def dt(self) -> Union[float, int]:
|
def dt(self) -> Union[float, int]:
|
||||||
return self.env.dt
|
return self.env.dt
|
||||||
|
16
example.py
16
example.py
@ -1,7 +1,9 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from alr_envs.utils.mp_env_async_sampler import AlrMpEnvSampler, AlrContextualMpEnvSampler, DummyDist
|
|
||||||
|
from alr_envs.utils.mp_env_async_sampler import AlrContextualMpEnvSampler, AlrMpEnvSampler, DummyDist
|
||||||
|
|
||||||
|
|
||||||
def example_mujoco():
|
def example_mujoco():
|
||||||
@ -14,8 +16,8 @@ def example_mujoco():
|
|||||||
obs, reward, done, info = env.step(env.action_space.sample())
|
obs, reward, done, info = env.step(env.action_space.sample())
|
||||||
rewards += reward
|
rewards += reward
|
||||||
|
|
||||||
if i % 1 == 0:
|
# if i % 1 == 0:
|
||||||
env.render()
|
# env.render()
|
||||||
|
|
||||||
if done:
|
if done:
|
||||||
print(rewards)
|
print(rewards)
|
||||||
@ -23,8 +25,7 @@ def example_mujoco():
|
|||||||
obs = env.reset()
|
obs = env.reset()
|
||||||
|
|
||||||
|
|
||||||
def example_mp(env_name="alr_envs:HoleReacherDMP-v0"):
|
def example_mp(env_name="alr_envs:HoleReacherDMP-v1"):
|
||||||
# env = gym.make("alr_envs:ViaPointReacherDMP-v0")
|
|
||||||
env = gym.make(env_name)
|
env = gym.make(env_name)
|
||||||
rewards = 0
|
rewards = 0
|
||||||
# env.render(mode=None)
|
# env.render(mode=None)
|
||||||
@ -105,11 +106,12 @@ def example_async_contextual_sampler(env_name="alr_envs:SimpleReacherDMP-v1", n_
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
example_mp("alr_envs:HoleReacherDetPMP-v0")
|
||||||
# example_mujoco()
|
# example_mujoco()
|
||||||
# example_mp("alr_envs:SimpleReacherDMP-v1")
|
# example_mp("alr_envs:SimpleReacherDMP-v1")
|
||||||
# example_async("alr_envs:LongSimpleReacherDMP-v0", 4)
|
# example_async("alr_envs:LongSimpleReacherDMP-v0", 4)
|
||||||
# example_async_contextual_sampler()
|
# example_async_contextual_sampler()
|
||||||
# env = gym.make("alr_envs:HoleReacherDetPMP-v1")
|
# env = gym.make("alr_envs:HoleReacherDetPMP-v1")
|
||||||
env_name = "alr_envs:ALRBallInACupPDSimpleDetPMP-v0"
|
# env_name = "alr_envs:ALRBallInACupPDSimpleDetPMP-v0"
|
||||||
example_async_sampler(env_name)
|
# example_async_sampler(env_name)
|
||||||
# example_mp(env_name)
|
# example_mp(env_name)
|
||||||
|
Loading…
Reference in New Issue
Block a user