added DMC Reacher, cartpole, reach_size; removed BBO
This commit is contained in:
parent
7e2f5d664b
commit
24e73d5098
@ -48,12 +48,6 @@ All environments provide the full episode reward and additional information abou
|
||||
|
||||
[//]: |`HoleReacherDetPMP-v0`|
|
||||
|
||||
### Stochastic Search
|
||||
|Name| Description|Horizon|Action Dimension|Observation Dimension
|
||||
|---|---|---|---|---|
|
||||
|`Rosenbrock{dim}-v0`| Gym interface for Rosenbrock function. `{dim}` is one of 5, 10, 25, 50 or 100. | 1 | `{dim}` | 0
|
||||
|
||||
|
||||
## Install
|
||||
1. Clone the repository
|
||||
```bash
|
||||
|
@ -4,9 +4,11 @@ from gym.envs.registration import register
|
||||
from alr_envs.classic_control.hole_reacher.hole_reacher_mp_wrapper import HoleReacherMPWrapper
|
||||
from alr_envs.classic_control.simple_reacher.simple_reacher_mp_wrapper import SimpleReacherMPWrapper
|
||||
from alr_envs.classic_control.viapoint_reacher.viapoint_reacher_mp_wrapper import ViaPointReacherMPWrapper
|
||||
from alr_envs.dmc.ball_in_cup.ball_in_the_cup_mp_wrapper import DMCBallInCupMPWrapper
|
||||
from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup_mp_wrapper import BallInACupMPWrapper
|
||||
from alr_envs.stochastic_search.functions.f_rosenbrock import Rosenbrock
|
||||
from alr_envs.dmc.manipulation.reach.reach_mp_wrapper import DMCReachSiteMPWrapper
|
||||
from alr_envs.dmc.suite.ball_in_cup.ball_in_cup_mp_wrapper import DMCBallInCupMPWrapper
|
||||
from alr_envs.dmc.suite.cartpole.cartpole_mp_wrapper import DMCCartpoleMPWrapper, DMCCartpoleThreePolesMPWrapper, \
|
||||
DMCCartpoleTwoPolesMPWrapper
|
||||
from alr_envs.dmc.suite.reacher.reacher_mp_wrapper import DMCReacherMPWrapper
|
||||
|
||||
# Mujoco
|
||||
|
||||
@ -88,54 +90,6 @@ register(
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id='ALRBallInACupSimple-v0',
|
||||
entry_point='alr_envs.mujoco:ALRBallInACupEnv',
|
||||
max_episode_steps=4000,
|
||||
kwargs={
|
||||
"simplified": True,
|
||||
"reward_type": "no_context",
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id='ALRBallInACupPDSimple-v0',
|
||||
entry_point='alr_envs.mujoco:ALRBallInACupPDEnv',
|
||||
max_episode_steps=4000,
|
||||
kwargs={
|
||||
"simplified": True,
|
||||
"reward_type": "no_context"
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id='ALRBallInACupPD-v0',
|
||||
entry_point='alr_envs.mujoco:ALRBallInACupPDEnv',
|
||||
max_episode_steps=4000,
|
||||
kwargs={
|
||||
"simplified": False,
|
||||
"reward_type": "no_context"
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id='ALRBallInACup-v0',
|
||||
entry_point='alr_envs.mujoco:ALRBallInACupEnv',
|
||||
max_episode_steps=4000,
|
||||
kwargs={
|
||||
"reward_type": "no_context"
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id='ALRBallInACupGoal-v0',
|
||||
entry_point='alr_envs.mujoco:ALRBallInACupEnv',
|
||||
max_episode_steps=4000,
|
||||
kwargs={
|
||||
"reward_type": "contextual_goal"
|
||||
}
|
||||
)
|
||||
|
||||
# Classic control
|
||||
|
||||
## Simple Reacher
|
||||
@ -239,7 +193,7 @@ register(
|
||||
}
|
||||
)
|
||||
|
||||
# MP environments
|
||||
# Motion Primitive Environments
|
||||
|
||||
## Simple Reacher
|
||||
versions = ["SimpleReacher-v0", "SimpleReacher-v1", "LongSimpleReacher-v0", "LongSimpleReacher-v1"]
|
||||
@ -255,7 +209,7 @@ for v in versions:
|
||||
"mp_kwargs": {
|
||||
"num_dof": 2 if "long" not in v.lower() else 5,
|
||||
"num_basis": 5,
|
||||
"duration": 2,
|
||||
"duration": 20,
|
||||
"alpha_phase": 2,
|
||||
"learn_goal": True,
|
||||
"policy_type": "velocity",
|
||||
@ -325,200 +279,14 @@ for v in versions:
|
||||
}
|
||||
)
|
||||
|
||||
# TODO: properly add final_pos
|
||||
register(
|
||||
id='HoleReacherFixedGoalDMP-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
|
||||
# max_episode_steps=1,
|
||||
kwargs={
|
||||
"name": "alr_envs:HoleReacher-v0",
|
||||
"wrappers": [HoleReacherMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 5,
|
||||
"num_basis": 5,
|
||||
"duration": 2,
|
||||
"learn_goal": False,
|
||||
"alpha_phase": 2,
|
||||
"policy_type": "velocity",
|
||||
"weights_scale": 50,
|
||||
"goal_scale": 0.1
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
## Ball in Cup
|
||||
|
||||
register(
|
||||
id='ALRBallInACupSimpleDMP-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
|
||||
kwargs={
|
||||
"name": "alr_envs:ALRBallInACupSimple-v0",
|
||||
"wrappers": [BallInACupMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 3,
|
||||
"num_basis": 5,
|
||||
"duration": 3.5,
|
||||
"post_traj_time": 4.5,
|
||||
"learn_goal": False,
|
||||
"alpha_phase": 3,
|
||||
"bandwidth_factor": 2.5,
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 100,
|
||||
"return_to_start": True,
|
||||
"policy_kwargs": {
|
||||
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
|
||||
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id='ALRBallInACupDMP-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
|
||||
kwargs={
|
||||
"name": "alr_envs:ALRBallInACup-v0",
|
||||
"wrappers": [BallInACupMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 7,
|
||||
"num_basis": 5,
|
||||
"duration": 3.5,
|
||||
"post_traj_time": 4.5,
|
||||
"learn_goal": False,
|
||||
"alpha_phase": 3,
|
||||
"bandwidth_factor": 2.5,
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 100,
|
||||
"return_to_start": True,
|
||||
"policy_kwargs": {
|
||||
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
|
||||
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id='ALRBallInACupSimpleDetPMP-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||
kwargs={
|
||||
"name": "alr_envs:ALRBallInACupSimple-v0",
|
||||
"wrappers": [BallInACupMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 3,
|
||||
"num_basis": 5,
|
||||
"duration": 3.5,
|
||||
"post_traj_time": 4.5,
|
||||
"width": 0.0035,
|
||||
# "off": -0.05,
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 0.2,
|
||||
"zero_start": True,
|
||||
"zero_goal": True,
|
||||
"policy_kwargs": {
|
||||
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
|
||||
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id='ALRBallInACupPDSimpleDetPMP-v0',
|
||||
entry_point='alr_envs.mujoco.ball_in_a_cup.biac_pd:make_detpmp_env_helper',
|
||||
kwargs={
|
||||
"name": "alr_envs:ALRBallInACupPDSimple-v0",
|
||||
"wrappers": [BallInACupMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 3,
|
||||
"num_basis": 5,
|
||||
"duration": 3.5,
|
||||
"post_traj_time": 4.5,
|
||||
"width": 0.0035,
|
||||
# "off": -0.05,
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 0.2,
|
||||
"zero_start": True,
|
||||
"zero_goal": True,
|
||||
"policy_kwargs": {
|
||||
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
|
||||
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id='ALRBallInACupPDDetPMP-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env',
|
||||
kwargs={
|
||||
"name": "alr_envs:ALRBallInACupPD-v0",
|
||||
"num_dof": 7,
|
||||
"num_basis": 5,
|
||||
"duration": 3.5,
|
||||
"post_traj_time": 4.5,
|
||||
"width": 0.0035,
|
||||
# "off": -0.05,
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 0.2,
|
||||
"zero_start": True,
|
||||
"zero_goal": True,
|
||||
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
|
||||
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id='ALRBallInACupDetPMP-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||
kwargs={
|
||||
"name": "alr_envs:ALRBallInACupSimple-v0",
|
||||
"wrappers": [BallInACupMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 7,
|
||||
"num_basis": 5,
|
||||
"duration": 3.5,
|
||||
"post_traj_time": 4.5,
|
||||
"width": 0.0035,
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 0.2,
|
||||
"zero_start": True,
|
||||
"zero_goal": True,
|
||||
"policy_kwargs": {
|
||||
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
|
||||
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id='ALRBallInACupGoalDMP-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_contextual_env',
|
||||
kwargs={
|
||||
"name": "alr_envs:ALRBallInACupGoal-v0",
|
||||
"wrappers": [BallInACupMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 7,
|
||||
"num_basis": 5,
|
||||
"duration": 3.5,
|
||||
"post_traj_time": 4.5,
|
||||
"learn_goal": True,
|
||||
"alpha_phase": 3,
|
||||
"bandwidth_factor": 2.5,
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 50,
|
||||
"goal_scale": 0.1,
|
||||
"policy_kwargs": {
|
||||
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
|
||||
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
## DMC
|
||||
## Deep Mind Control Suite (DMC)
|
||||
### Suite
|
||||
|
||||
# tasks = ["ball_in_cup-catch", "reacher-easy", "reacher-hard", "cartpole-balance", "cartpole-balance_sparse",
|
||||
# "cartpole-swingup", "cartpole-swingup_sparse", "cartpole-two_poles", "cartpole-three_poles"]
|
||||
# wrappers = [DMCBallInCupMPWrapper, DMCReacherMPWrapper, DMCReacherMPWrapper, DMCCartpoleMPWrapper,
|
||||
# partial(DMCCartpoleMPWrapper)]
|
||||
# for t, w in zip(tasks, wrappers):
|
||||
register(
|
||||
id=f'dmc_ball_in_cup-catch_dmp-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
|
||||
@ -570,14 +338,455 @@ register(
|
||||
}
|
||||
)
|
||||
|
||||
# BBO functions
|
||||
|
||||
for dim in [5, 10, 25, 50, 100]:
|
||||
register(
|
||||
id=f'Rosenbrock{dim}-v0',
|
||||
entry_point='alr_envs.stochastic_search:StochasticSearchEnv',
|
||||
max_episode_steps=1,
|
||||
# TODO tune gains and episode length for all below
|
||||
register(
|
||||
id=f'dmc_reacher-easy_dmp-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
|
||||
# max_episode_steps=1,
|
||||
kwargs={
|
||||
"cost_f": Rosenbrock(dim),
|
||||
"name": f"reacher-easy",
|
||||
"time_limit": 1,
|
||||
"episode_length": 50,
|
||||
"wrappers": [DMCReacherMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 2,
|
||||
"num_basis": 5,
|
||||
"duration": 1,
|
||||
"learn_goal": True,
|
||||
"alpha_phase": 2,
|
||||
"bandwidth_factor": 2,
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 50,
|
||||
"goal_scale": 0.1,
|
||||
"policy_kwargs": {
|
||||
"p_gains": 50,
|
||||
"d_gains": 1
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id=f'dmc_reacher-easy_detpmp-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||
kwargs={
|
||||
"name": f"reacher-easy",
|
||||
"time_limit": 1,
|
||||
"episode_length": 50,
|
||||
"wrappers": [DMCReacherMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 2,
|
||||
"num_basis": 5,
|
||||
"duration": 1,
|
||||
"width": 0.025,
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 0.2,
|
||||
"zero_start": True,
|
||||
"policy_kwargs": {
|
||||
"p_gains": 50,
|
||||
"d_gains": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id=f'dmc_reacher-hard_dmp-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
|
||||
# max_episode_steps=1,
|
||||
kwargs={
|
||||
"name": f"reacher-hard",
|
||||
"time_limit": 1,
|
||||
"episode_length": 50,
|
||||
"wrappers": [DMCReacherMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 2,
|
||||
"num_basis": 5,
|
||||
"duration": 1,
|
||||
"learn_goal": True,
|
||||
"alpha_phase": 2,
|
||||
"bandwidth_factor": 2,
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 50,
|
||||
"goal_scale": 0.1,
|
||||
"policy_kwargs": {
|
||||
"p_gains": 50,
|
||||
"d_gains": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id=f'dmc_reacher-hard_detpmp-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||
kwargs={
|
||||
"name": f"reacher-hard",
|
||||
"time_limit": 1,
|
||||
"episode_length": 50,
|
||||
"wrappers": [DMCReacherMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 2,
|
||||
"num_basis": 5,
|
||||
"duration": 1,
|
||||
"width": 0.025,
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 0.2,
|
||||
"zero_start": True,
|
||||
"policy_kwargs": {
|
||||
"p_gains": 50,
|
||||
"d_gains": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
register(
|
||||
id=f'dmc_cartpole-balance_dmp-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
|
||||
# max_episode_steps=1,
|
||||
kwargs={
|
||||
"name": f"cartpole-balance",
|
||||
# "time_limit": 1,
|
||||
"episode_length": 1000,
|
||||
"wrappers": [DMCCartpoleMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 1,
|
||||
"num_basis": 5,
|
||||
"duration": 10,
|
||||
"learn_goal": True,
|
||||
"alpha_phase": 2,
|
||||
"bandwidth_factor": 2,
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 50,
|
||||
"goal_scale": 0.1,
|
||||
"policy_kwargs": {
|
||||
"p_gains": 50,
|
||||
"d_gains": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id=f'dmc_cartpole-balance_detpmp-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||
kwargs={
|
||||
"name": f"cartpole-balance",
|
||||
# "time_limit": 1,
|
||||
"episode_length": 1000,
|
||||
"wrappers": [DMCCartpoleMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 1,
|
||||
"num_basis": 5,
|
||||
"duration": 10,
|
||||
"width": 0.025,
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 0.2,
|
||||
"zero_start": True,
|
||||
"policy_kwargs": {
|
||||
"p_gains": 50,
|
||||
"d_gains": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
register(
|
||||
id=f'dmc_cartpole-balance_sparse_dmp-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
|
||||
# max_episode_steps=1,
|
||||
kwargs={
|
||||
"name": f"cartpole-balance_sparse",
|
||||
# "time_limit": 1,
|
||||
"episode_length": 1000,
|
||||
"wrappers": [DMCCartpoleMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 1,
|
||||
"num_basis": 5,
|
||||
"duration": 10,
|
||||
"learn_goal": True,
|
||||
"alpha_phase": 2,
|
||||
"bandwidth_factor": 2,
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 50,
|
||||
"goal_scale": 0.1,
|
||||
"policy_kwargs": {
|
||||
"p_gains": 50,
|
||||
"d_gains": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id=f'dmc_cartpole-balance_sparse_detpmp-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||
kwargs={
|
||||
"name": f"cartpole-balance_sparse",
|
||||
# "time_limit": 1,
|
||||
"episode_length": 1000,
|
||||
"wrappers": [DMCCartpoleMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 1,
|
||||
"num_basis": 5,
|
||||
"duration": 10,
|
||||
"width": 0.025,
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 0.2,
|
||||
"zero_start": True,
|
||||
"policy_kwargs": {
|
||||
"p_gains": 50,
|
||||
"d_gains": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id=f'dmc_cartpole-swingup_dmp-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
|
||||
# max_episode_steps=1,
|
||||
kwargs={
|
||||
"name": f"cartpole-swingup",
|
||||
# "time_limit": 1,
|
||||
"episode_length": 1000,
|
||||
"wrappers": [DMCCartpoleMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 1,
|
||||
"num_basis": 5,
|
||||
"duration": 10,
|
||||
"learn_goal": True,
|
||||
"alpha_phase": 2,
|
||||
"bandwidth_factor": 2,
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 50,
|
||||
"goal_scale": 0.1,
|
||||
"policy_kwargs": {
|
||||
"p_gains": 50,
|
||||
"d_gains": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id=f'dmc_cartpole-swingup_detpmp-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||
kwargs={
|
||||
"name": f"cartpole-swingup",
|
||||
# "time_limit": 1,
|
||||
"episode_length": 1000,
|
||||
"wrappers": [DMCCartpoleMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 1,
|
||||
"num_basis": 5,
|
||||
"duration": 10,
|
||||
"width": 0.025,
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 0.2,
|
||||
"zero_start": True,
|
||||
"policy_kwargs": {
|
||||
"p_gains": 50,
|
||||
"d_gains": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
register(
|
||||
id=f'dmc_cartpole-swingup_sparse_dmp-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
|
||||
# max_episode_steps=1,
|
||||
kwargs={
|
||||
"name": f"cartpole-swingup_sparse",
|
||||
# "time_limit": 1,
|
||||
"episode_length": 1000,
|
||||
"wrappers": [DMCCartpoleMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 1,
|
||||
"num_basis": 5,
|
||||
"duration": 10,
|
||||
"learn_goal": True,
|
||||
"alpha_phase": 2,
|
||||
"bandwidth_factor": 2,
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 50,
|
||||
"goal_scale": 0.1,
|
||||
"policy_kwargs": {
|
||||
"p_gains": 50,
|
||||
"d_gains": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id=f'dmc_cartpole-swingup_sparse_detpmp-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||
kwargs={
|
||||
"name": f"cartpole-swingup_sparse",
|
||||
# "time_limit": 1,
|
||||
"episode_length": 1000,
|
||||
"wrappers": [DMCCartpoleMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 1,
|
||||
"num_basis": 5,
|
||||
"duration": 10,
|
||||
"width": 0.025,
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 0.2,
|
||||
"zero_start": True,
|
||||
"policy_kwargs": {
|
||||
"p_gains": 50,
|
||||
"d_gains": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
register(
|
||||
id=f'dmc_cartpole-two_poles_dmp-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
|
||||
# max_episode_steps=1,
|
||||
kwargs={
|
||||
"name": f"cartpole-two_poles",
|
||||
# "time_limit": 1,
|
||||
"episode_length": 1000,
|
||||
# "wrappers": [partial(DMCCartpoleMPWrapper, n_poles=2)],
|
||||
"wrappers": [DMCCartpoleTwoPolesMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 1,
|
||||
"num_basis": 5,
|
||||
"duration": 10,
|
||||
"learn_goal": True,
|
||||
"alpha_phase": 2,
|
||||
"bandwidth_factor": 2,
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 50,
|
||||
"goal_scale": 0.1,
|
||||
"policy_kwargs": {
|
||||
"p_gains": 50,
|
||||
"d_gains": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id=f'dmc_cartpole-two_poles_detpmp-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||
kwargs={
|
||||
"name": f"cartpole-two_poles",
|
||||
# "time_limit": 1,
|
||||
"episode_length": 1000,
|
||||
# "wrappers": [partial(DMCCartpoleMPWrapper, n_poles=2)],
|
||||
"wrappers": [DMCCartpoleTwoPolesMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 1,
|
||||
"num_basis": 5,
|
||||
"duration": 10,
|
||||
"width": 0.025,
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 0.2,
|
||||
"zero_start": True,
|
||||
"policy_kwargs": {
|
||||
"p_gains": 50,
|
||||
"d_gains": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
register(
|
||||
id=f'dmc_cartpole-three_poles_dmp-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
|
||||
# max_episode_steps=1,
|
||||
kwargs={
|
||||
"name": f"cartpole-three_poles",
|
||||
# "time_limit": 1,
|
||||
"episode_length": 1000,
|
||||
# "wrappers": [partial(DMCCartpoleMPWrapper, n_poles=3)],
|
||||
"wrappers": [DMCCartpoleThreePolesMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 1,
|
||||
"num_basis": 5,
|
||||
"duration": 10,
|
||||
"learn_goal": True,
|
||||
"alpha_phase": 2,
|
||||
"bandwidth_factor": 2,
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 50,
|
||||
"goal_scale": 0.1,
|
||||
"policy_kwargs": {
|
||||
"p_gains": 50,
|
||||
"d_gains": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id=f'dmc_cartpole-three_poles_detpmp-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||
kwargs={
|
||||
"name": f"cartpole-three_poles",
|
||||
# "time_limit": 1,
|
||||
"episode_length": 1000,
|
||||
# "wrappers": [partial(DMCCartpoleMPWrapper, n_poles=3)],
|
||||
"wrappers": [DMCCartpoleThreePolesMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 1,
|
||||
"num_basis": 5,
|
||||
"duration": 10,
|
||||
"width": 0.025,
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 0.2,
|
||||
"zero_start": True,
|
||||
"policy_kwargs": {
|
||||
"p_gains": 50,
|
||||
"d_gains": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
### Manipulation
|
||||
|
||||
register(
|
||||
id=f'dmc_manipulation-reach_site_dmp-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
|
||||
# max_episode_steps=1,
|
||||
kwargs={
|
||||
"name": f"manipulation-reach_site_features",
|
||||
# "time_limit": 1,
|
||||
"episode_length": 250,
|
||||
"wrappers": [DMCReachSiteMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 9,
|
||||
"num_basis": 5,
|
||||
"duration": 10,
|
||||
"learn_goal": True,
|
||||
"alpha_phase": 2,
|
||||
"bandwidth_factor": 2,
|
||||
"policy_type": "velocity",
|
||||
"weights_scale": 50,
|
||||
"goal_scale": 0.1,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id=f'dmc_manipulation-reach_site_detpmp-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||
kwargs={
|
||||
"name": f"manipulation-reach_site_features",
|
||||
# "time_limit": 1,
|
||||
"episode_length": 250,
|
||||
"wrappers": [DMCReachSiteMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 9,
|
||||
"num_basis": 5,
|
||||
"duration": 10,
|
||||
"width": 0.025,
|
||||
"policy_type": "velocity",
|
||||
"weights_scale": 0.2,
|
||||
"zero_start": True,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
@ -128,7 +128,7 @@ class HoleReacherEnv(gym.Env):
|
||||
width = np.copy(self.initial_width)
|
||||
if self.initial_x is None:
|
||||
# sample whole on left or right side
|
||||
direction = np.random.choice([-1, 1])
|
||||
direction = self.np_random.choice([-1, 1])
|
||||
# Hole center needs to be half the width away from the arm to give a valid setting.
|
||||
x = direction * self.np_random.uniform(width / 2, 3.5)
|
||||
else:
|
||||
@ -263,7 +263,7 @@ class HoleReacherEnv(gym.Env):
|
||||
self.fig.show()
|
||||
|
||||
self.fig.gca().set_title(
|
||||
f"Iteration: {self._steps}, distance: {self.end_effector - self._goal}")
|
||||
f"Iteration: {self._steps}, distance: {np.linalg.norm(self.end_effector - self._goal) ** 2}")
|
||||
|
||||
if mode == "human":
|
||||
|
||||
|
38
alr_envs/dmc/manipulation/reach/reach_mp_wrapper.py
Normal file
38
alr_envs/dmc/manipulation/reach/reach_mp_wrapper.py
Normal file
@ -0,0 +1,38 @@
|
||||
from typing import Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper
|
||||
|
||||
|
||||
class DMCReachSiteMPWrapper(MPEnvWrapper):
|
||||
|
||||
@property
|
||||
def active_obs(self):
|
||||
# Joint and target positions are randomized, velocities are always set to 0.
|
||||
return np.hstack([
|
||||
[True] * 3, # target position
|
||||
[True] * 12, # sin/cos arm joint position
|
||||
[True] * 6, # arm joint torques
|
||||
[False] * 6, # arm joint velocities
|
||||
[True] * 3, # sin/cos hand joint position
|
||||
[False] * 3, # hand joint velocities
|
||||
[True] * 3, # hand pinch site position
|
||||
[True] * 9, # pinch site rmat
|
||||
])
|
||||
|
||||
@property
|
||||
def current_pos(self) -> Union[float, int, np.ndarray]:
|
||||
return self.env.physics.named.data.qpos[:]
|
||||
|
||||
@property
|
||||
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||
return self.env.physics.named.data.qvel[:]
|
||||
|
||||
@property
|
||||
def goal_pos(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
|
||||
|
||||
@property
|
||||
def dt(self) -> Union[float, int]:
|
||||
return self.env.dt
|
0
alr_envs/dmc/suite/__init__.py
Normal file
0
alr_envs/dmc/suite/__init__.py
Normal file
0
alr_envs/dmc/suite/ball_in_cup/__init__.py
Normal file
0
alr_envs/dmc/suite/ball_in_cup/__init__.py
Normal file
@ -19,11 +19,11 @@ class DMCBallInCupMPWrapper(MPEnvWrapper):
|
||||
|
||||
@property
|
||||
def current_pos(self) -> Union[float, int, np.ndarray]:
|
||||
return np.hstack([self.physics.named.data.qpos['cup_x'], self.physics.named.data.qpos['cup_z']])
|
||||
return np.hstack([self.env.physics.named.data.qpos['cup_x'], self.env.physics.named.data.qpos['cup_z']])
|
||||
|
||||
@property
|
||||
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||
return np.hstack([self.physics.named.data.qvel['cup_x'], self.physics.named.data.qvel['cup_z']])
|
||||
return np.hstack([self.env.physics.named.data.qvel['cup_x'], self.env.physics.named.data.qvel['cup_z']])
|
||||
|
||||
@property
|
||||
def goal_pos(self) -> Union[float, int, np.ndarray, Tuple]:
|
0
alr_envs/dmc/suite/reacher/__init__.py
Normal file
0
alr_envs/dmc/suite/reacher/__init__.py
Normal file
33
alr_envs/dmc/suite/reacher/reacher_mp_wrapper.py
Normal file
33
alr_envs/dmc/suite/reacher/reacher_mp_wrapper.py
Normal file
@ -0,0 +1,33 @@
|
||||
from typing import Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper
|
||||
|
||||
|
||||
class DMCReacherMPWrapper(MPEnvWrapper):
|
||||
|
||||
@property
|
||||
def active_obs(self):
|
||||
# Joint and target positions are randomized, velocities are always set to 0.
|
||||
return np.hstack([
|
||||
[True] * 2, # joint position
|
||||
[True] * 2, # target position
|
||||
[False] * 2, # joint velocity
|
||||
])
|
||||
|
||||
@property
|
||||
def current_pos(self) -> Union[float, int, np.ndarray]:
|
||||
return self.env.physics.named.data.qpos[:]
|
||||
|
||||
@property
|
||||
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||
return self.env.physics.named.data.qvel[:]
|
||||
|
||||
@property
|
||||
def goal_pos(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
|
||||
|
||||
@property
|
||||
def dt(self) -> Union[float, int]:
|
||||
return self.env.dt
|
@ -1,4 +1,4 @@
|
||||
from alr_envs.dmc.ball_in_cup.ball_in_the_cup_mp_wrapper import DMCBallInCupMPWrapper
|
||||
from alr_envs.dmc.suite.ball_in_cup.ball_in_cup_mp_wrapper import DMCBallInCupMPWrapper
|
||||
from alr_envs.utils.make_env_helpers import make_dmp_env, make_env
|
||||
|
||||
|
||||
@ -123,11 +123,11 @@ if __name__ == '__main__':
|
||||
render = False
|
||||
|
||||
# # Standard DMC Suite tasks
|
||||
# example_dmc("fish-swim", seed=10, iterations=1000, render=render)
|
||||
#
|
||||
# # Manipulation tasks
|
||||
# # Disclaimer: The vision versions are currently not integrated and yield an error
|
||||
# example_dmc("manipulation-reach_site_features", seed=10, iterations=250, render=render)
|
||||
example_dmc("fish-swim", seed=10, iterations=1000, render=render)
|
||||
|
||||
# Manipulation tasks
|
||||
# Disclaimer: The vision versions are currently not integrated and yield an error
|
||||
example_dmc("manipulation-reach_site_features", seed=10, iterations=250, render=render)
|
||||
|
||||
# Gym + DMC hybrid task provided in the MP framework
|
||||
example_dmc("dmc_ball_in_cup-catch_detpmp-v0", seed=10, iterations=1, render=render)
|
||||
|
@ -86,17 +86,18 @@ def example_async(env_id="alr_envs:HoleReacher-v0", n_cpu=4, seed=int('533D', 16
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
render = False
|
||||
# Basic gym task
|
||||
example_general("Pendulum-v0", seed=10, iterations=200, render=True)
|
||||
example_general("Pendulum-v0", seed=10, iterations=200, render=render)
|
||||
#
|
||||
# # Basis task from framework
|
||||
example_general("alr_envs:HoleReacher-v0", seed=10, iterations=200, render=True)
|
||||
example_general("alr_envs:HoleReacher-v0", seed=10, iterations=200, render=render)
|
||||
#
|
||||
# # OpenAI Mujoco task
|
||||
example_general("HalfCheetah-v2", seed=10, render=True)
|
||||
example_general("HalfCheetah-v2", seed=10, render=render)
|
||||
#
|
||||
# # Mujoco task from framework
|
||||
example_general("alr_envs:ALRReacher-v0", seed=10, iterations=200, render=True)
|
||||
example_general("alr_envs:ALRReacher-v0", seed=10, iterations=200, render=render)
|
||||
|
||||
# Vectorized multiprocessing environments
|
||||
example_async(env_id="alr_envs:HoleReacher-v0", n_cpu=2, seed=int('533D', 16), n_samples=2 * 200)
|
||||
|
@ -148,14 +148,15 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
render = False
|
||||
# DMP
|
||||
example_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=True)
|
||||
example_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=render)
|
||||
|
||||
# DetProMP
|
||||
example_mp("alr_envs:HoleReacherDetPMP-v1", seed=10, iterations=1, render=True)
|
||||
example_mp("alr_envs:HoleReacherDetPMP-v1", seed=10, iterations=1, render=render)
|
||||
|
||||
# Altered basis functions
|
||||
example_custom_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=True)
|
||||
example_custom_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=render)
|
||||
|
||||
# Custom MP
|
||||
example_fully_custom_mp(seed=10, iterations=1, render=True)
|
||||
example_fully_custom_mp(seed=10, iterations=1, render=render)
|
||||
|
@ -1,4 +1,4 @@
|
||||
from alr_envs.mujoco.reacher.alr_reacher import ALRReacherEnv
|
||||
from alr_envs.mujoco.balancing import BalancingEnv
|
||||
# from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup import ALRBallInACupEnv
|
||||
from alr_envs.mujoco.reacher.balancing import BalancingEnv
|
||||
from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup import ALRBallInACupEnv
|
||||
from alr_envs.mujoco.ball_in_a_cup.biac_pd import ALRBallInACupPDEnv
|
||||
|
@ -1 +0,0 @@
|
||||
from alr_envs.stochastic_search.stochastic_search import StochasticSearchEnv
|
@ -1,76 +0,0 @@
|
||||
import numpy as np
|
||||
import scipy.stats as scistats
|
||||
|
||||
np.seterr(divide='ignore', invalid='ignore')
|
||||
|
||||
|
||||
class BaseObjective(object):
|
||||
def __init__(self, dim, int_opt=None, val_opt=None, alpha=None, beta=None):
|
||||
self.dim = dim
|
||||
self.alpha = alpha
|
||||
self.beta = beta
|
||||
# check if optimal parameter is in interval...
|
||||
if int_opt is not None:
|
||||
self.x_opt = np.random.uniform(int_opt[0], int_opt[1], size=(1, dim))
|
||||
# ... or based on a single value
|
||||
elif val_opt is not None:
|
||||
self.one_pm = np.where(np.random.rand(1, dim) > 0.5, 1, -1)
|
||||
self.x_opt = val_opt * self.one_pm
|
||||
else:
|
||||
raise ValueError("Optimal value or interval has to be defined")
|
||||
self.f_opt = np.round(np.clip(scistats.cauchy.rvs(loc=0, scale=100, size=1)[0], -1000, 1000), decimals=2)
|
||||
self.i = np.arange(self.dim)
|
||||
self._lambda_alpha = None
|
||||
self._q = None
|
||||
self._r = None
|
||||
|
||||
def __call__(self, x):
|
||||
return self.evaluate_full(x)
|
||||
|
||||
def evaluate_full(self, x):
|
||||
raise NotImplementedError("Subclasses should implement this!")
|
||||
|
||||
def gs(self):
|
||||
# Gram Schmidt ortho-normalization
|
||||
a = np.random.randn(self.dim, self.dim)
|
||||
b, _ = np.linalg.qr(a)
|
||||
return b
|
||||
|
||||
# TODO: property probably unnecessary
|
||||
@property
|
||||
def q(self):
|
||||
if self._q is None:
|
||||
self._q = self.gs()
|
||||
return self._q
|
||||
|
||||
@property
|
||||
def r(self):
|
||||
if self._r is None:
|
||||
self._r = self.gs()
|
||||
return self._r
|
||||
|
||||
@property
|
||||
def lambda_alpha(self):
|
||||
if self._lambda_alpha is None:
|
||||
if isinstance(self.alpha, int):
|
||||
lambda_ii = np.power(self.alpha, 1 / 2 * self.i / (self.dim - 1))
|
||||
self._lambda_alpha = np.diag(lambda_ii)
|
||||
else:
|
||||
lambda_ii = np.power(self.alpha[:, None], 1 / 2 * self.i[None, :] / (self.dim - 1))
|
||||
self._lambda_alpha = np.stack([np.diag(l_ii) for l_ii in lambda_ii])
|
||||
return self._lambda_alpha
|
||||
|
||||
@staticmethod
|
||||
def f_pen(x):
|
||||
return np.sum(np.maximum(0, np.abs(x) - 5), axis=1)
|
||||
|
||||
def t_asy_beta(self, x):
|
||||
# exp = np.power(x, 1 + self.beta * self.i[:, None] / (self.input_dim - 1) * np.sqrt(x))
|
||||
# return np.where(x > 0, exp, x)
|
||||
return x
|
||||
|
||||
def t_osz(self, x):
|
||||
x_hat = np.where(x != 0, np.log(np.abs(x)), 0)
|
||||
c_1 = np.where(x > 0, 10, 5.5)
|
||||
c_2 = np.where(x > 0, 7.9, 3.1)
|
||||
return np.sign(x) * np.exp(x_hat + 0.049 * (np.sin(c_1 * x_hat) + np.sin(c_2 * x_hat)))
|
@ -1,56 +0,0 @@
|
||||
import numpy as np
|
||||
|
||||
from alr_envs.stochastic_search.functions.f_base import BaseObjective
|
||||
|
||||
|
||||
class Rosenbrock(BaseObjective):
|
||||
def __init__(self, dim, int_opt=(-3., 3.)):
|
||||
super(Rosenbrock, self).__init__(dim, int_opt=int_opt)
|
||||
self.c = np.maximum(1, np.sqrt(self.dim) / 8)
|
||||
|
||||
def evaluate_full(self, x):
|
||||
x = np.atleast_2d(x)
|
||||
assert x.shape[1] == self.dim
|
||||
|
||||
z = self.c * (x - self.x_opt) + 1
|
||||
z_end = z[:, 1:]
|
||||
z_begin = z[:, :-1]
|
||||
|
||||
a = z_begin ** 2 - z_end
|
||||
b = z_begin - 1
|
||||
|
||||
return np.sum(100 * a ** 2 + b ** 2, axis=1) + self.f_opt
|
||||
|
||||
|
||||
class RosenbrockRotated(BaseObjective):
|
||||
def __init__(self, dim, int_opt=(-3., 3.)):
|
||||
super(RosenbrockRotated, self).__init__(dim, int_opt=int_opt)
|
||||
self.c = np.maximum(1, np.sqrt(self.dim) / 8)
|
||||
|
||||
def evaluate_full(self, x):
|
||||
x = np.atleast_2d(x)
|
||||
assert x.shape[1] == self.dim
|
||||
|
||||
z = (self.c * self.r @ x.T + 1 / 2).T
|
||||
a = z[:, :-1] ** 2 - z[:, 1:]
|
||||
b = z[:, :-1] - 1
|
||||
|
||||
return np.sum(100 * a ** 2 + b ** 2, axis=1) + self.f_opt
|
||||
|
||||
|
||||
class RosenbrockRaw(BaseObjective):
|
||||
def __init__(self, dim, int_opt=(-3., 3.)):
|
||||
super(RosenbrockRaw, self).__init__(dim, int_opt=int_opt)
|
||||
self.x_opt = np.ones((1, dim))
|
||||
self.f_opt = 0
|
||||
|
||||
def evaluate_full(self, x):
|
||||
x = np.atleast_2d(x)
|
||||
assert x.shape[1] == self.dim
|
||||
|
||||
a = x[:, :-1] ** 2 - x[:, 1:]
|
||||
b = x[:, :-1] - 1
|
||||
|
||||
out = np.sum(100 * a ** 2 + b ** 2, axis=1)
|
||||
|
||||
return out
|
@ -1,22 +0,0 @@
|
||||
import gym
|
||||
import numpy as np
|
||||
|
||||
from alr_envs.stochastic_search.functions.f_base import BaseObjective
|
||||
|
||||
|
||||
class StochasticSearchEnv(gym.Env):
|
||||
|
||||
def __init__(self, cost_f: BaseObjective):
|
||||
self.cost_f = cost_f
|
||||
|
||||
self.action_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.cost_f.dim,), dtype=np.float64)
|
||||
self.observation_space = gym.spaces.Box(low=(), high=(), shape=(), dtype=np.float64)
|
||||
|
||||
def step(self, action):
|
||||
return np.zeros(self.observation_space.shape), np.squeeze(-self.cost_f(action)), True, {}
|
||||
|
||||
def reset(self):
|
||||
return np.zeros(self.observation_space.shape)
|
||||
|
||||
def render(self, mode='human'):
|
||||
pass
|
@ -49,6 +49,8 @@ def make_env(env_id: str, seed, **kwargs):
|
||||
# Gym
|
||||
env = gym.make(env_id, **kwargs)
|
||||
env.seed(seed)
|
||||
env.action_space.seed(seed)
|
||||
env.observation_space.seed(seed)
|
||||
except gym.error.Error:
|
||||
# DMC
|
||||
from alr_envs.utils import make
|
||||
@ -79,7 +81,7 @@ def _make_wrapped_env(env_id: str, wrappers: Iterable[Type[gym.Wrapper]], seed=1
|
||||
_env = make_env(env_id, seed, **kwargs)
|
||||
|
||||
assert any(issubclass(w, MPEnvWrapper) for w in wrappers), \
|
||||
"At least an MPEnvWrapper is required in order to leverage motion primitive environments."
|
||||
"At least one MPEnvWrapper is required in order to leverage motion primitive environments."
|
||||
for w in wrappers:
|
||||
_env = w(_env)
|
||||
|
||||
|
3
setup.py
3
setup.py
@ -17,5 +17,6 @@ setup(
|
||||
license='MIT',
|
||||
author='Fabian Otto, Marcel Sandermann, Maximilian Huettenrauch',
|
||||
author_email='',
|
||||
description='Custom Gym environments for various (robotics) simple_reacher.'
|
||||
description='Custom Gym environments for various (robotics) tasks. integration of DMC environments into the'
|
||||
'gym interface, and support for using motion primitives with gym environments.'
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user