Ported classic_control envs to fancy registry

This commit is contained in:
Dominik Moritz Roth 2023-07-14 14:31:36 +02:00
parent 6c90f8ade2
commit f375a6e4df
4 changed files with 152 additions and 230 deletions

View File

@ -1,7 +1,8 @@
from copy import deepcopy
import numpy as np
from gymnasium import register
from gymnasium import register as gym_register
from .registry import register, ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS
from . import classic_control, mujoco
from .classic_control.hole_reacher.hole_reacher import HoleReacherEnv
@ -21,80 +22,8 @@ from .mujoco.box_pushing.box_pushing_env import BoxPushingDense, BoxPushingTempo
from .mujoco.table_tennis.table_tennis_env import TableTennisEnv, TableTennisWind, TableTennisGoalSwitching, \
MAX_EPISODE_STEPS_TABLE_TENNIS
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}
# Template black-box configuration for ProMP environments.
# Registrations deep-copy this dict and then override individual entries
# (wrappers, controller gains, 'name', ...), so it must stay a plain nested dict.
DEFAULT_BB_DICT_ProMP = dict(
    name='EnvName',  # placeholder; replaced with the wrapped env id per registration
    wrappers=[],
    trajectory_generator_kwargs=dict(
        trajectory_generator_type='promp',
    ),
    phase_generator_kwargs=dict(
        phase_generator_type='linear',
    ),
    controller_kwargs=dict(
        controller_type='motor',
        p_gains=1.0,
        d_gains=0.1,
    ),
    basis_generator_kwargs=dict(
        basis_generator_type='zero_rbf',
        num_basis=5,
        num_basis_zero_start=1,
        basis_bandwidth_factor=3.0,
    ),
    black_box_kwargs=dict(),
)
# Template black-box configuration for DMP environments.
# Registrations deep-copy this dict and then override individual entries.
# Unlike the ProMP/ProDMP templates it carries no 'black_box_kwargs' entry.
DEFAULT_BB_DICT_DMP = dict(
    name='EnvName',  # placeholder; replaced with the wrapped env id per registration
    wrappers=[],
    trajectory_generator_kwargs=dict(
        trajectory_generator_type='dmp',
    ),
    phase_generator_kwargs=dict(
        phase_generator_type='exp',
    ),
    controller_kwargs=dict(
        controller_type='motor',
        p_gains=1.0,
        d_gains=0.1,
    ),
    basis_generator_kwargs=dict(
        basis_generator_type='rbf',
        num_basis=5,
    ),
)
# Template black-box configuration for ProDMP environments.
# Registrations deep-copy this dict and then override individual entries
# (e.g. the replanning settings stored under 'black_box_kwargs').
DEFAULT_BB_DICT_ProDMP = dict(
    name='EnvName',  # placeholder; replaced with the wrapped env id per registration
    wrappers=[],
    trajectory_generator_kwargs=dict(
        trajectory_generator_type='prodmp',
        duration=2.0,
        weights_scale=1.0,
    ),
    phase_generator_kwargs=dict(
        phase_generator_type='exp',
        tau=1.5,
    ),
    controller_kwargs=dict(
        controller_type='motor',
        p_gains=1.0,
        d_gains=0.1,
    ),
    basis_generator_kwargs=dict(
        basis_generator_type='prodmp',
        alpha=10,
        num_basis=5,
    ),
    black_box_kwargs=dict(),
)
# Classic Control
## Simple Reacher
# Simple Reacher
register(
id='SimpleReacher-v0',
entry_point='fancy_gym.envs.classic_control:SimpleReacherEnv',
@ -113,8 +42,7 @@ register(
}
)
## Viapoint Reacher
# Viapoint Reacher
register(
id='ViaPointReacher-v0',
entry_point='fancy_gym.envs.classic_control:ViaPointReacherEnv',
@ -126,7 +54,7 @@ register(
}
)
## Hole Reacher
# Hole Reacher
register(
id='HoleReacher-v0',
entry_point='fancy_gym.envs.classic_control:HoleReacherEnv',
@ -145,9 +73,9 @@ register(
# Mujoco
## Mujoco Reacher
# Mujoco Reacher
for _dims in [5, 7]:
register(
gym_register(
id=f'Reacher{_dims}d-v0',
entry_point='fancy_gym.envs.mujoco:ReacherEnv',
max_episode_steps=MAX_EPISODE_STEPS_REACHER,
@ -156,7 +84,7 @@ for _dims in [5, 7]:
}
)
register(
gym_register(
id=f'Reacher{_dims}dSparse-v0',
entry_point='fancy_gym.envs.mujoco:ReacherEnv',
max_episode_steps=MAX_EPISODE_STEPS_REACHER,
@ -167,7 +95,7 @@ for _dims in [5, 7]:
}
)
register(
gym_register(
id='HopperJumpSparse-v0',
entry_point='fancy_gym.envs.mujoco:HopperJumpEnv',
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP,
@ -176,7 +104,7 @@ register(
}
)
register(
gym_register(
id='HopperJump-v0',
entry_point='fancy_gym.envs.mujoco:HopperJumpEnv',
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP,
@ -188,43 +116,43 @@ register(
}
)
register(
gym_register(
id='AntJump-v0',
entry_point='fancy_gym.envs.mujoco:AntJumpEnv',
max_episode_steps=MAX_EPISODE_STEPS_ANTJUMP,
)
register(
gym_register(
id='HalfCheetahJump-v0',
entry_point='fancy_gym.envs.mujoco:HalfCheetahJumpEnv',
max_episode_steps=MAX_EPISODE_STEPS_HALFCHEETAHJUMP,
)
register(
gym_register(
id='HopperJumpOnBox-v0',
entry_point='fancy_gym.envs.mujoco:HopperJumpOnBoxEnv',
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMPONBOX,
)
register(
gym_register(
id='HopperThrow-v0',
entry_point='fancy_gym.envs.mujoco:HopperThrowEnv',
max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROW,
)
register(
gym_register(
id='HopperThrowInBasket-v0',
entry_point='fancy_gym.envs.mujoco:HopperThrowInBasketEnv',
max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROWINBASKET,
)
register(
gym_register(
id='Walker2DJump-v0',
entry_point='fancy_gym.envs.mujoco:Walker2dJumpEnv',
max_episode_steps=MAX_EPISODE_STEPS_WALKERJUMP,
)
register(
gym_register(
id='BeerPong-v0',
entry_point='fancy_gym.envs.mujoco:BeerPongEnv',
max_episode_steps=MAX_EPISODE_STEPS_BEERPONG,
@ -232,7 +160,7 @@ register(
# Box pushing environments with different rewards
for reward_type in ["Dense", "TemporalSparse", "TemporalSpatialSparse"]:
register(
gym_register(
id='BoxPushing{}-v0'.format(reward_type),
entry_point='fancy_gym.envs.mujoco:BoxPushing{}'.format(reward_type),
max_episode_steps=MAX_EPISODE_STEPS_BOX_PUSHING,
@ -240,7 +168,7 @@ for reward_type in ["Dense", "TemporalSparse", "TemporalSpatialSparse"]:
# Here we use the same reward as in BeerPong-v0, but now consider after the release,
# only one time step, i.e. we simulate until the end of the episode
register(
gym_register(
id='BeerPongStepBased-v0',
entry_point='fancy_gym.envs.mujoco:BeerPongEnvStepBasedEpisodicReward',
max_episode_steps=FIXED_RELEASE_STEP,
@ -248,7 +176,7 @@ register(
# Table Tennis environments
for ctxt_dim in [2, 4]:
register(
gym_register(
id='TableTennis{}D-v0'.format(ctxt_dim),
entry_point='fancy_gym.envs.mujoco:TableTennisEnv',
max_episode_steps=MAX_EPISODE_STEPS_TABLE_TENNIS,
@ -258,13 +186,13 @@ for ctxt_dim in [2, 4]:
}
)
register(
gym_register(
id='TableTennisWind-v0',
entry_point='fancy_gym.envs.mujoco:TableTennisWind',
max_episode_steps=MAX_EPISODE_STEPS_TABLE_TENNIS,
)
register(
gym_register(
id='TableTennisGoalSwitching-v0',
entry_point='fancy_gym.envs.mujoco:TableTennisGoalSwitching',
max_episode_steps=MAX_EPISODE_STEPS_TABLE_TENNIS,
@ -276,98 +204,13 @@ register(
# movement Primitive Environments
## Simple Reacher
_versions = ["SimpleReacher-v0", "LongSimpleReacher-v0"]
for _v in _versions:
_name = _v.split("-")
_env_id = f'{_name[0]}DMP-{_name[1]}'
kwargs_dict_simple_reacher_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_simple_reacher_dmp['wrappers'].append(classic_control.simple_reacher.MPWrapper)
kwargs_dict_simple_reacher_dmp['controller_kwargs']['p_gains'] = 0.6
kwargs_dict_simple_reacher_dmp['controller_kwargs']['d_gains'] = 0.075
kwargs_dict_simple_reacher_dmp['trajectory_generator_kwargs']['weight_scale'] = 50
kwargs_dict_simple_reacher_dmp['phase_generator_kwargs']['alpha_phase'] = 2
kwargs_dict_simple_reacher_dmp['name'] = f"{_v}"
register(
id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_simple_reacher_dmp
)
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
# Simple Reacher [DONE]
_env_id = f'{_name[0]}ProMP-{_name[1]}'
kwargs_dict_simple_reacher_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
kwargs_dict_simple_reacher_promp['wrappers'].append(classic_control.simple_reacher.MPWrapper)
kwargs_dict_simple_reacher_promp['controller_kwargs']['p_gains'] = 0.6
kwargs_dict_simple_reacher_promp['controller_kwargs']['d_gains'] = 0.075
kwargs_dict_simple_reacher_promp['name'] = _v
register(
id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_simple_reacher_promp
)
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
# Viapoint reacher [DONE]
# Viapoint reacher
kwargs_dict_via_point_reacher_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_via_point_reacher_dmp['wrappers'].append(classic_control.viapoint_reacher.MPWrapper)
kwargs_dict_via_point_reacher_dmp['controller_kwargs']['controller_type'] = 'velocity'
kwargs_dict_via_point_reacher_dmp['trajectory_generator_kwargs']['weight_scale'] = 50
kwargs_dict_via_point_reacher_dmp['phase_generator_kwargs']['alpha_phase'] = 2
kwargs_dict_via_point_reacher_dmp['name'] = "ViaPointReacher-v0"
register(
id='ViaPointReacherDMP-v0',
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
# max_episode_steps=1,
kwargs=kwargs_dict_via_point_reacher_dmp
)
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("ViaPointReacherDMP-v0")
# Hole Reacher [DONE]
kwargs_dict_via_point_reacher_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
kwargs_dict_via_point_reacher_promp['wrappers'].append(classic_control.viapoint_reacher.MPWrapper)
kwargs_dict_via_point_reacher_promp['controller_kwargs']['controller_type'] = 'velocity'
kwargs_dict_via_point_reacher_promp['name'] = "ViaPointReacher-v0"
register(
id="ViaPointReacherProMP-v0",
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_via_point_reacher_promp
)
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append("ViaPointReacherProMP-v0")
## Hole Reacher
_versions = ["HoleReacher-v0"]
for _v in _versions:
_name = _v.split("-")
_env_id = f'{_name[0]}DMP-{_name[1]}'
kwargs_dict_hole_reacher_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_hole_reacher_dmp['wrappers'].append(classic_control.hole_reacher.MPWrapper)
kwargs_dict_hole_reacher_dmp['controller_kwargs']['controller_type'] = 'velocity'
# TODO: Before it was weight scale 50 and goal scale 0.1. We now only have weight scale and thus set it to 500. Check
kwargs_dict_hole_reacher_dmp['trajectory_generator_kwargs']['weight_scale'] = 500
kwargs_dict_hole_reacher_dmp['phase_generator_kwargs']['alpha_phase'] = 2.5
kwargs_dict_hole_reacher_dmp['name'] = _v
register(
id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
# max_episode_steps=1,
kwargs=kwargs_dict_hole_reacher_dmp
)
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
_env_id = f'{_name[0]}ProMP-{_name[1]}'
kwargs_dict_hole_reacher_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
kwargs_dict_hole_reacher_promp['wrappers'].append(classic_control.hole_reacher.MPWrapper)
kwargs_dict_hole_reacher_promp['trajectory_generator_kwargs']['weight_scale'] = 2
kwargs_dict_hole_reacher_promp['controller_kwargs']['controller_type'] = 'velocity'
kwargs_dict_hole_reacher_promp['name'] = f"{_v}"
register(
id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_hole_reacher_promp
)
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
## ReacherNd
# ReacherNd
_versions = ["Reacher5d-v0", "Reacher7d-v0", "Reacher5dSparse-v0", "Reacher7dSparse-v0"]
for _v in _versions:
_name = _v.split("-")
@ -376,7 +219,7 @@ for _v in _versions:
kwargs_dict_reacher_dmp['wrappers'].append(mujoco.reacher.MPWrapper)
kwargs_dict_reacher_dmp['phase_generator_kwargs']['alpha_phase'] = 2
kwargs_dict_reacher_dmp['name'] = _v
register(
gym_register(
id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
# max_episode_steps=1,
@ -388,14 +231,14 @@ for _v in _versions:
kwargs_dict_reacher_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
kwargs_dict_reacher_promp['wrappers'].append(mujoco.reacher.MPWrapper)
kwargs_dict_reacher_promp['name'] = _v
register(
gym_register(
id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_reacher_promp
)
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
########################################################################################################################
## Beerpong ProMP
# Beerpong ProMP
_versions = ['BeerPong-v0']
for _v in _versions:
_name = _v.split("-")
@ -408,14 +251,14 @@ for _v in _versions:
kwargs_dict_bp_promp['basis_generator_kwargs']['num_basis'] = 2
kwargs_dict_bp_promp['basis_generator_kwargs']['num_basis_zero_start'] = 2
kwargs_dict_bp_promp['name'] = _v
register(
gym_register(
id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_bp_promp
)
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
### BP with Fixed release
# BP with Fixed release
_versions = ["BeerPongStepBased-v0", 'BeerPong-v0']
for _v in _versions:
if _v != 'BeerPong-v0':
@ -431,7 +274,7 @@ for _v in _versions:
kwargs_dict_bp_promp['basis_generator_kwargs']['num_basis'] = 2
kwargs_dict_bp_promp['basis_generator_kwargs']['num_basis_zero_start'] = 2
kwargs_dict_bp_promp['name'] = _v
register(
gym_register(
id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_bp_promp
@ -439,7 +282,7 @@ for _v in _versions:
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
########################################################################################################################
## Table Tennis needs to be fixed according to Zhou's implementation
# Table Tennis needs to be fixed according to Zhou's implementation
# TODO: Add later when finished
# ########################################################################################################################
@ -452,7 +295,7 @@ for _v in _versions:
# kwargs_dict_ant_jump_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
# kwargs_dict_ant_jump_promp['wrappers'].append(mujoco.ant_jump.MPWrapper)
# kwargs_dict_ant_jump_promp['name'] = _v
# register(
# gym_register(
# id=_env_id,
# entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
# kwargs=kwargs_dict_ant_jump_promp
@ -469,7 +312,7 @@ for _v in _versions:
# kwargs_dict_halfcheetah_jump_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
# kwargs_dict_halfcheetah_jump_promp['wrappers'].append(mujoco.half_cheetah_jump.MPWrapper)
# kwargs_dict_halfcheetah_jump_promp['name'] = _v
# register(
# gym_register(
# id=_env_id,
# entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
# kwargs=kwargs_dict_halfcheetah_jump_promp
@ -479,7 +322,7 @@ for _v in _versions:
# ########################################################################################################################
## HopperJump
# HopperJump
_versions = ['HopperJump-v0', 'HopperJumpSparse-v0',
# 'HopperJumpOnBox-v0', 'HopperThrow-v0', 'HopperThrowInBasket-v0'
]
@ -490,7 +333,7 @@ for _v in _versions:
kwargs_dict_hopper_jump_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
kwargs_dict_hopper_jump_promp['wrappers'].append(mujoco.hopper_jump.MPWrapper)
kwargs_dict_hopper_jump_promp['name'] = _v
register(
gym_register(
id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_hopper_jump_promp
@ -499,7 +342,7 @@ for _v in _versions:
# ########################################################################################################################
## Box Pushing
# Box Pushing
_versions = ['BoxPushingDense-v0', 'BoxPushingTemporalSparse-v0', 'BoxPushingTemporalSpatialSparse-v0']
for _v in _versions:
_name = _v.split("-")
@ -511,7 +354,7 @@ for _v in _versions:
kwargs_dict_box_pushing_promp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.])
kwargs_dict_box_pushing_promp['basis_generator_kwargs']['basis_bandwidth_factor'] = 2 # 3.5, 4 to try
register(
gym_register(
id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_box_pushing_promp
@ -535,16 +378,16 @@ for _v in _versions:
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3
kwargs_dict_box_pushing_prodmp['phase_generator_kwargs']['alpha_phase'] = 3
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_planning_times'] = 4
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 25 == 0
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t: t % 25 == 0
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['condition_on_desired'] = True
register(
gym_register(
id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_box_pushing_prodmp
)
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
## Table Tennis
# Table Tennis
_versions = ['TableTennis2D-v0', 'TableTennis4D-v0', 'TableTennisWind-v0', 'TableTennisGoalSwitching-v0']
for _v in _versions:
_name = _v.split("-")
@ -565,7 +408,7 @@ for _v in _versions:
kwargs_dict_tt_promp['basis_generator_kwargs']['num_basis_zero_start'] = 1
kwargs_dict_tt_promp['basis_generator_kwargs']['num_basis_zero_goal'] = 1
kwargs_dict_tt_promp['black_box_kwargs']['verbose'] = 2
register(
gym_register(
id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_tt_promp
@ -595,7 +438,7 @@ for _v in _versions:
kwargs_dict_tt_prodmp['basis_generator_kwargs']['alpha'] = 25.
kwargs_dict_tt_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3
kwargs_dict_tt_prodmp['phase_generator_kwargs']['alpha_phase'] = 3
register(
gym_register(
id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_tt_prodmp
@ -624,8 +467,8 @@ for _v in _versions:
kwargs_dict_tt_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3
kwargs_dict_tt_prodmp['phase_generator_kwargs']['alpha_phase'] = 3
kwargs_dict_tt_prodmp['black_box_kwargs']['max_planning_times'] = 3
kwargs_dict_tt_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 50 == 0
register(
kwargs_dict_tt_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t: t % 50 == 0
gym_register(
id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_tt_prodmp
@ -640,16 +483,16 @@ for _v in _versions:
# kwargs_dict_walker2d_jump_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
# kwargs_dict_walker2d_jump_promp['wrappers'].append(mujoco.walker_2d_jump.MPWrapper)
# kwargs_dict_walker2d_jump_promp['name'] = _v
# register(
# gym_register(
# id=_env_id,
# entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
# kwargs=kwargs_dict_walker2d_jump_promp
# )
# ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
### Deprecated, we will not provide non random starts anymore
# Deprecated, we will not provide non random starts anymore
"""
register(
gym_register(
id='SimpleReacher-v1',
entry_point='fancy_gym.envs.classic_control:SimpleReacherEnv',
max_episode_steps=200,
@ -659,7 +502,7 @@ register(
}
)
register(
gym_register(
id='LongSimpleReacher-v1',
entry_point='fancy_gym.envs.classic_control:SimpleReacherEnv',
max_episode_steps=200,
@ -668,7 +511,7 @@ register(
"random_start": False
}
)
register(
gym_register(
id='HoleReacher-v1',
entry_point='fancy_gym.envs.classic_control:HoleReacherEnv',
max_episode_steps=200,
@ -683,7 +526,7 @@ register(
"collision_penalty": 100,
}
)
register(
gym_register(
id='HoleReacher-v2',
entry_point='fancy_gym.envs.classic_control:HoleReacherEnv',
max_episode_steps=200,
@ -700,7 +543,7 @@ register(
)
# CtxtFree are v0, Contextual are v1
register(
gym_register(
id='AntJump-v0',
entry_point='fancy_gym.envs.mujoco:AntJumpEnv',
max_episode_steps=MAX_EPISODE_STEPS_ANTJUMP,
@ -710,7 +553,7 @@ register(
}
)
# CtxtFree are v0, Contextual are v1
register(
gym_register(
id='HalfCheetahJump-v0',
entry_point='fancy_gym.envs.mujoco:HalfCheetahJumpEnv',
max_episode_steps=MAX_EPISODE_STEPS_HALFCHEETAHJUMP,
@ -719,7 +562,7 @@ register(
"context": False
}
)
register(
gym_register(
id='HopperJump-v0',
entry_point='fancy_gym.envs.mujoco:HopperJumpEnv',
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP,
@ -732,12 +575,12 @@ register(
"""
### Deprecated used for CorL paper
# Deprecated used for CorL paper
"""
_vs = np.arange(101).tolist() + [1e-5, 5e-5, 1e-4, 5e-4, 1e-3, 5e-3, 1e-2, 5e-2, 1e-1, 5e-1]
for i in _vs:
_env_id = f'ALRReacher{i}-v0'
register(
gym_register(
id=_env_id,
entry_point='fancy_gym.envs.mujoco:ReacherEnv',
max_episode_steps=200,
@ -750,7 +593,7 @@ for i in _vs:
)
_env_id = f'ALRReacherSparse{i}-v0'
register(
gym_register(
id=_env_id,
entry_point='fancy_gym.envs.mujoco:ReacherEnv',
max_episode_steps=200,
@ -764,7 +607,7 @@ for i in _vs:
_vs = np.arange(101).tolist() + [1e-5, 5e-5, 1e-4, 5e-4, 1e-3, 5e-3, 1e-2, 5e-2, 1e-1, 5e-1]
for i in _vs:
_env_id = f'ALRReacher{i}ProMP-v0'
register(
gym_register(
id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_promp_env_helper',
kwargs={
@ -787,7 +630,7 @@ for i in _vs:
)
_env_id = f'ALRReacherSparse{i}ProMP-v0'
register(
gym_register(
id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_promp_env_helper',
kwargs={
@ -809,7 +652,7 @@ for i in _vs:
}
)
register(
gym_register(
id='HopperJumpOnBox-v0',
entry_point='fancy_gym.envs.mujoco:HopperJumpOnBoxEnv',
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMPONBOX,
@ -818,7 +661,7 @@ for i in _vs:
"context": False
}
)
register(
gym_register(
id='HopperThrow-v0',
entry_point='fancy_gym.envs.mujoco:HopperThrowEnv',
max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROW,
@ -827,7 +670,7 @@ for i in _vs:
"context": False
}
)
register(
gym_register(
id='HopperThrowInBasket-v0',
entry_point='fancy_gym.envs.mujoco:HopperThrowInBasketEnv',
max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROWINBASKET,
@ -836,7 +679,7 @@ for i in _vs:
"context": False
}
)
register(
gym_register(
id='Walker2DJump-v0',
entry_point='fancy_gym.envs.mujoco:Walker2dJumpEnv',
max_episode_steps=MAX_EPISODE_STEPS_WALKERJUMP,
@ -845,12 +688,12 @@ for i in _vs:
"context": False
}
)
register(id='TableTennis2DCtxt-v1',
gym_register(id='TableTennis2DCtxt-v1',
entry_point='fancy_gym.envs.mujoco:TTEnvGym',
max_episode_steps=MAX_EPISODE_STEPS,
kwargs={'ctxt_dim': 2, 'fixed_goal': True})
register(
gym_register(
id='BeerPong-v0',
entry_point='fancy_gym.envs.mujoco:BeerBongEnv',
max_episode_steps=300,

View File

@ -8,11 +8,41 @@ from gymnasium.core import ObsType
from matplotlib import patches
from fancy_gym.envs.classic_control.base_reacher.base_reacher_direct import BaseReacherDirectEnv
from . import MPWrapper
MAX_EPISODE_STEPS_HOLEREACHER = 200
class HoleReacherEnv(BaseReacherDirectEnv):
# Environment-specific movement-primitive settings, grouped under 'mp_config'
# by primitive type ('ProMP', 'DMP', 'ProDMP').
# NOTE(review): presumably read by the registry that registers this
# environment and layered over its per-type defaults -- confirm.
metadata = {
'mp_config': {
# ProMP variant: MPWrapper, velocity controller, weight scale 2.
'ProMP': {
'wrappers': [MPWrapper],
'controller_kwargs': {
'controller_type': 'velocity',
},
'trajectory_generator_kwargs': {
'weight_scale': 2,
},
},
# DMP variant: MPWrapper, velocity controller, weight scale 500, alpha_phase 2.5.
'DMP': {
'wrappers': [MPWrapper],
'controller_kwargs': {
'controller_type': 'velocity',
},
'trajectory_generator_kwargs': {
# TODO: Before it was weight scale 50 and goal scale 0.1. We now only have weight scale and thus set it to 500. Check
'weight_scale': 500,
},
'phase_generator_kwargs': {
'alpha_phase': 2.5,
},
},
# No ProDMP-specific settings are given for this environment.
'ProDMP': {},
}
}
def __init__(self, n_links: int, hole_x: Union[None, float] = None, hole_depth: Union[None, float] = None,
hole_width: float = 1., random_start: bool = False, allow_self_collision: bool = False,
allow_wall_collision: bool = False, collision_penalty: float = 1000, rew_fct: str = "simple"):

View File

@ -6,6 +6,7 @@ from gymnasium import spaces
from gymnasium.core import ObsType
from fancy_gym.envs.classic_control.base_reacher.base_reacher_torque import BaseReacherTorqueEnv
from . import MPWrapper
class SimpleReacherEnv(BaseReacherTorqueEnv):
@ -15,6 +16,32 @@ class SimpleReacherEnv(BaseReacherTorqueEnv):
towards the end of the trajectory.
"""
# Environment-specific movement-primitive settings, grouped under 'mp_config'
# by primitive type ('ProMP', 'DMP', 'ProDMP').
# NOTE(review): presumably read by the registry that registers this
# environment and layered over its per-type defaults -- confirm.
metadata = {
'mp_config': {
# ProMP variant: MPWrapper with controller gains p=0.6, d=0.075.
'ProMP': {
'wrappers': [MPWrapper],
'controller_kwargs': {
'p_gains': 0.6,
'd_gains': 0.075,
},
},
# DMP variant: same wrapper and gains, plus weight scale 50 and alpha_phase 2.
'DMP': {
'wrappers': [MPWrapper],
'controller_kwargs': {
'p_gains': 0.6,
'd_gains': 0.075,
},
'trajectory_generator_kwargs': {
'weight_scale': 50,
},
'phase_generator_kwargs': {
'alpha_phase': 2,
},
},
# No ProDMP-specific settings are given for this environment.
'ProDMP': {},
}
}
def __init__(self, n_links: int, target: Union[None, Iterable] = None, random_start: bool = True,
allow_self_collision: bool = False, ):
super().__init__(n_links, random_start, allow_self_collision)
@ -126,4 +153,3 @@ class SimpleReacherEnv(BaseReacherTorqueEnv):
self.fig.canvas.draw()
self.fig.canvas.flush_events()

View File

@ -7,10 +7,35 @@ from gymnasium import spaces
from gymnasium.core import ObsType
from fancy_gym.envs.classic_control.base_reacher.base_reacher_direct import BaseReacherDirectEnv
from . import MPWrapper
class ViaPointReacherEnv(BaseReacherDirectEnv):
# Environment-specific movement-primitive settings, grouped under 'mp_config'
# by primitive type ('ProMP', 'DMP', 'ProDMP').
# NOTE(review): presumably read by the registry that registers this
# environment and layered over its per-type defaults -- confirm.
metadata = {
'mp_config': {
# ProMP variant: MPWrapper with a velocity controller.
'ProMP': {
'wrappers': [MPWrapper],
'controller_kwargs': {
'controller_type': 'velocity',
},
},
# DMP variant: same wrapper and controller, plus weight scale 50 and alpha_phase 2.
'DMP': {
'wrappers': [MPWrapper],
'controller_kwargs': {
'controller_type': 'velocity',
},
'trajectory_generator_kwargs': {
'weight_scale': 50,
},
'phase_generator_kwargs': {
'alpha_phase': 2,
},
},
# No ProDMP-specific settings are given for this environment.
'ProDMP': {},
}
}
def __init__(self, n_links, random_start: bool = False, via_target: Union[None, Iterable] = None,
target: Union[None, Iterable] = None, allow_self_collision=False, collision_penalty=1000):
@ -184,5 +209,3 @@ class ViaPointReacherEnv(BaseReacherDirectEnv):
plt.plot(self._joints[:, 0], self._joints[:, 1], 'ro-', markerfacecolor='k')
plt.pause(0.01)