Ported classic_control envs to fancy registry

This commit is contained in:
Dominik Moritz Roth 2023-07-14 14:31:36 +02:00
parent 6c90f8ade2
commit f375a6e4df
4 changed files with 152 additions and 230 deletions

View File

@ -1,7 +1,8 @@
from copy import deepcopy from copy import deepcopy
import numpy as np import numpy as np
from gymnasium import register from gymnasium import register as gym_register
from .registry import register, ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS
from . import classic_control, mujoco from . import classic_control, mujoco
from .classic_control.hole_reacher.hole_reacher import HoleReacherEnv from .classic_control.hole_reacher.hole_reacher import HoleReacherEnv
@ -21,80 +22,8 @@ from .mujoco.box_pushing.box_pushing_env import BoxPushingDense, BoxPushingTempo
from .mujoco.table_tennis.table_tennis_env import TableTennisEnv, TableTennisWind, TableTennisGoalSwitching, \ from .mujoco.table_tennis.table_tennis_env import TableTennisEnv, TableTennisWind, TableTennisGoalSwitching, \
MAX_EPISODE_STEPS_TABLE_TENNIS MAX_EPISODE_STEPS_TABLE_TENNIS
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}
DEFAULT_BB_DICT_ProMP = {
"name": 'EnvName',
"wrappers": [],
"trajectory_generator_kwargs": {
'trajectory_generator_type': 'promp'
},
"phase_generator_kwargs": {
'phase_generator_type': 'linear'
},
"controller_kwargs": {
'controller_type': 'motor',
"p_gains": 1.0,
"d_gains": 0.1,
},
"basis_generator_kwargs": {
'basis_generator_type': 'zero_rbf',
'num_basis': 5,
'num_basis_zero_start': 1,
'basis_bandwidth_factor': 3.0,
},
"black_box_kwargs": {
}
}
DEFAULT_BB_DICT_DMP = {
"name": 'EnvName',
"wrappers": [],
"trajectory_generator_kwargs": {
'trajectory_generator_type': 'dmp'
},
"phase_generator_kwargs": {
'phase_generator_type': 'exp'
},
"controller_kwargs": {
'controller_type': 'motor',
"p_gains": 1.0,
"d_gains": 0.1,
},
"basis_generator_kwargs": {
'basis_generator_type': 'rbf',
'num_basis': 5
}
}
DEFAULT_BB_DICT_ProDMP = {
"name": 'EnvName',
"wrappers": [],
"trajectory_generator_kwargs": {
'trajectory_generator_type': 'prodmp',
'duration': 2.0,
'weights_scale': 1.0,
},
"phase_generator_kwargs": {
'phase_generator_type': 'exp',
'tau': 1.5,
},
"controller_kwargs": {
'controller_type': 'motor',
"p_gains": 1.0,
"d_gains": 0.1,
},
"basis_generator_kwargs": {
'basis_generator_type': 'prodmp',
'alpha': 10,
'num_basis': 5,
},
"black_box_kwargs": {
}
}
# Classic Control # Classic Control
## Simple Reacher # Simple Reacher
register( register(
id='SimpleReacher-v0', id='SimpleReacher-v0',
entry_point='fancy_gym.envs.classic_control:SimpleReacherEnv', entry_point='fancy_gym.envs.classic_control:SimpleReacherEnv',
@ -113,8 +42,7 @@ register(
} }
) )
## Viapoint Reacher # Viapoint Reacher
register( register(
id='ViaPointReacher-v0', id='ViaPointReacher-v0',
entry_point='fancy_gym.envs.classic_control:ViaPointReacherEnv', entry_point='fancy_gym.envs.classic_control:ViaPointReacherEnv',
@ -126,7 +54,7 @@ register(
} }
) )
## Hole Reacher # Hole Reacher
register( register(
id='HoleReacher-v0', id='HoleReacher-v0',
entry_point='fancy_gym.envs.classic_control:HoleReacherEnv', entry_point='fancy_gym.envs.classic_control:HoleReacherEnv',
@ -145,9 +73,9 @@ register(
# Mujoco # Mujoco
## Mujoco Reacher # Mujoco Reacher
for _dims in [5, 7]: for _dims in [5, 7]:
register( gym_register(
id=f'Reacher{_dims}d-v0', id=f'Reacher{_dims}d-v0',
entry_point='fancy_gym.envs.mujoco:ReacherEnv', entry_point='fancy_gym.envs.mujoco:ReacherEnv',
max_episode_steps=MAX_EPISODE_STEPS_REACHER, max_episode_steps=MAX_EPISODE_STEPS_REACHER,
@ -156,7 +84,7 @@ for _dims in [5, 7]:
} }
) )
register( gym_register(
id=f'Reacher{_dims}dSparse-v0', id=f'Reacher{_dims}dSparse-v0',
entry_point='fancy_gym.envs.mujoco:ReacherEnv', entry_point='fancy_gym.envs.mujoco:ReacherEnv',
max_episode_steps=MAX_EPISODE_STEPS_REACHER, max_episode_steps=MAX_EPISODE_STEPS_REACHER,
@ -167,7 +95,7 @@ for _dims in [5, 7]:
} }
) )
register( gym_register(
id='HopperJumpSparse-v0', id='HopperJumpSparse-v0',
entry_point='fancy_gym.envs.mujoco:HopperJumpEnv', entry_point='fancy_gym.envs.mujoco:HopperJumpEnv',
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP, max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP,
@ -176,7 +104,7 @@ register(
} }
) )
register( gym_register(
id='HopperJump-v0', id='HopperJump-v0',
entry_point='fancy_gym.envs.mujoco:HopperJumpEnv', entry_point='fancy_gym.envs.mujoco:HopperJumpEnv',
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP, max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP,
@ -188,43 +116,43 @@ register(
} }
) )
register( gym_register(
id='AntJump-v0', id='AntJump-v0',
entry_point='fancy_gym.envs.mujoco:AntJumpEnv', entry_point='fancy_gym.envs.mujoco:AntJumpEnv',
max_episode_steps=MAX_EPISODE_STEPS_ANTJUMP, max_episode_steps=MAX_EPISODE_STEPS_ANTJUMP,
) )
register( gym_register(
id='HalfCheetahJump-v0', id='HalfCheetahJump-v0',
entry_point='fancy_gym.envs.mujoco:HalfCheetahJumpEnv', entry_point='fancy_gym.envs.mujoco:HalfCheetahJumpEnv',
max_episode_steps=MAX_EPISODE_STEPS_HALFCHEETAHJUMP, max_episode_steps=MAX_EPISODE_STEPS_HALFCHEETAHJUMP,
) )
register( gym_register(
id='HopperJumpOnBox-v0', id='HopperJumpOnBox-v0',
entry_point='fancy_gym.envs.mujoco:HopperJumpOnBoxEnv', entry_point='fancy_gym.envs.mujoco:HopperJumpOnBoxEnv',
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMPONBOX, max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMPONBOX,
) )
register( gym_register(
id='HopperThrow-v0', id='HopperThrow-v0',
entry_point='fancy_gym.envs.mujoco:HopperThrowEnv', entry_point='fancy_gym.envs.mujoco:HopperThrowEnv',
max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROW, max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROW,
) )
register( gym_register(
id='HopperThrowInBasket-v0', id='HopperThrowInBasket-v0',
entry_point='fancy_gym.envs.mujoco:HopperThrowInBasketEnv', entry_point='fancy_gym.envs.mujoco:HopperThrowInBasketEnv',
max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROWINBASKET, max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROWINBASKET,
) )
register( gym_register(
id='Walker2DJump-v0', id='Walker2DJump-v0',
entry_point='fancy_gym.envs.mujoco:Walker2dJumpEnv', entry_point='fancy_gym.envs.mujoco:Walker2dJumpEnv',
max_episode_steps=MAX_EPISODE_STEPS_WALKERJUMP, max_episode_steps=MAX_EPISODE_STEPS_WALKERJUMP,
) )
register( gym_register(
id='BeerPong-v0', id='BeerPong-v0',
entry_point='fancy_gym.envs.mujoco:BeerPongEnv', entry_point='fancy_gym.envs.mujoco:BeerPongEnv',
max_episode_steps=MAX_EPISODE_STEPS_BEERPONG, max_episode_steps=MAX_EPISODE_STEPS_BEERPONG,
@ -232,7 +160,7 @@ register(
# Box pushing environments with different rewards # Box pushing environments with different rewards
for reward_type in ["Dense", "TemporalSparse", "TemporalSpatialSparse"]: for reward_type in ["Dense", "TemporalSparse", "TemporalSpatialSparse"]:
register( gym_register(
id='BoxPushing{}-v0'.format(reward_type), id='BoxPushing{}-v0'.format(reward_type),
entry_point='fancy_gym.envs.mujoco:BoxPushing{}'.format(reward_type), entry_point='fancy_gym.envs.mujoco:BoxPushing{}'.format(reward_type),
max_episode_steps=MAX_EPISODE_STEPS_BOX_PUSHING, max_episode_steps=MAX_EPISODE_STEPS_BOX_PUSHING,
@ -240,7 +168,7 @@ for reward_type in ["Dense", "TemporalSparse", "TemporalSpatialSparse"]:
# Here we use the same reward as in BeerPong-v0, but now consider after the release, # Here we use the same reward as in BeerPong-v0, but now consider after the release,
# only one time step, i.e. we simulate until the end of th episode # only one time step, i.e. we simulate until the end of th episode
register( gym_register(
id='BeerPongStepBased-v0', id='BeerPongStepBased-v0',
entry_point='fancy_gym.envs.mujoco:BeerPongEnvStepBasedEpisodicReward', entry_point='fancy_gym.envs.mujoco:BeerPongEnvStepBasedEpisodicReward',
max_episode_steps=FIXED_RELEASE_STEP, max_episode_steps=FIXED_RELEASE_STEP,
@ -248,7 +176,7 @@ register(
# Table Tennis environments # Table Tennis environments
for ctxt_dim in [2, 4]: for ctxt_dim in [2, 4]:
register( gym_register(
id='TableTennis{}D-v0'.format(ctxt_dim), id='TableTennis{}D-v0'.format(ctxt_dim),
entry_point='fancy_gym.envs.mujoco:TableTennisEnv', entry_point='fancy_gym.envs.mujoco:TableTennisEnv',
max_episode_steps=MAX_EPISODE_STEPS_TABLE_TENNIS, max_episode_steps=MAX_EPISODE_STEPS_TABLE_TENNIS,
@ -258,13 +186,13 @@ for ctxt_dim in [2, 4]:
} }
) )
register( gym_register(
id='TableTennisWind-v0', id='TableTennisWind-v0',
entry_point='fancy_gym.envs.mujoco:TableTennisWind', entry_point='fancy_gym.envs.mujoco:TableTennisWind',
max_episode_steps=MAX_EPISODE_STEPS_TABLE_TENNIS, max_episode_steps=MAX_EPISODE_STEPS_TABLE_TENNIS,
) )
register( gym_register(
id='TableTennisGoalSwitching-v0', id='TableTennisGoalSwitching-v0',
entry_point='fancy_gym.envs.mujoco:TableTennisGoalSwitching', entry_point='fancy_gym.envs.mujoco:TableTennisGoalSwitching',
max_episode_steps=MAX_EPISODE_STEPS_TABLE_TENNIS, max_episode_steps=MAX_EPISODE_STEPS_TABLE_TENNIS,
@ -276,98 +204,13 @@ register(
# movement Primitive Environments # movement Primitive Environments
## Simple Reacher # Simple Reacher [DONE]
_versions = ["SimpleReacher-v0", "LongSimpleReacher-v0"]
for _v in _versions:
_name = _v.split("-")
_env_id = f'{_name[0]}DMP-{_name[1]}'
kwargs_dict_simple_reacher_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_simple_reacher_dmp['wrappers'].append(classic_control.simple_reacher.MPWrapper)
kwargs_dict_simple_reacher_dmp['controller_kwargs']['p_gains'] = 0.6
kwargs_dict_simple_reacher_dmp['controller_kwargs']['d_gains'] = 0.075
kwargs_dict_simple_reacher_dmp['trajectory_generator_kwargs']['weight_scale'] = 50
kwargs_dict_simple_reacher_dmp['phase_generator_kwargs']['alpha_phase'] = 2
kwargs_dict_simple_reacher_dmp['name'] = f"{_v}"
register(
id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_simple_reacher_dmp
)
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
_env_id = f'{_name[0]}ProMP-{_name[1]}' # Viapoint reacher [DONE]
kwargs_dict_simple_reacher_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
kwargs_dict_simple_reacher_promp['wrappers'].append(classic_control.simple_reacher.MPWrapper)
kwargs_dict_simple_reacher_promp['controller_kwargs']['p_gains'] = 0.6
kwargs_dict_simple_reacher_promp['controller_kwargs']['d_gains'] = 0.075
kwargs_dict_simple_reacher_promp['name'] = _v
register(
id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_simple_reacher_promp
)
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
# Viapoint reacher # Hole Reacher [DONE]
kwargs_dict_via_point_reacher_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_via_point_reacher_dmp['wrappers'].append(classic_control.viapoint_reacher.MPWrapper)
kwargs_dict_via_point_reacher_dmp['controller_kwargs']['controller_type'] = 'velocity'
kwargs_dict_via_point_reacher_dmp['trajectory_generator_kwargs']['weight_scale'] = 50
kwargs_dict_via_point_reacher_dmp['phase_generator_kwargs']['alpha_phase'] = 2
kwargs_dict_via_point_reacher_dmp['name'] = "ViaPointReacher-v0"
register(
id='ViaPointReacherDMP-v0',
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
# max_episode_steps=1,
kwargs=kwargs_dict_via_point_reacher_dmp
)
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("ViaPointReacherDMP-v0")
kwargs_dict_via_point_reacher_promp = deepcopy(DEFAULT_BB_DICT_ProMP) # ReacherNd
kwargs_dict_via_point_reacher_promp['wrappers'].append(classic_control.viapoint_reacher.MPWrapper)
kwargs_dict_via_point_reacher_promp['controller_kwargs']['controller_type'] = 'velocity'
kwargs_dict_via_point_reacher_promp['name'] = "ViaPointReacher-v0"
register(
id="ViaPointReacherProMP-v0",
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_via_point_reacher_promp
)
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append("ViaPointReacherProMP-v0")
## Hole Reacher
_versions = ["HoleReacher-v0"]
for _v in _versions:
_name = _v.split("-")
_env_id = f'{_name[0]}DMP-{_name[1]}'
kwargs_dict_hole_reacher_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_hole_reacher_dmp['wrappers'].append(classic_control.hole_reacher.MPWrapper)
kwargs_dict_hole_reacher_dmp['controller_kwargs']['controller_type'] = 'velocity'
# TODO: Before it was weight scale 50 and goal scale 0.1. We now only have weight scale and thus set it to 500. Check
kwargs_dict_hole_reacher_dmp['trajectory_generator_kwargs']['weight_scale'] = 500
kwargs_dict_hole_reacher_dmp['phase_generator_kwargs']['alpha_phase'] = 2.5
kwargs_dict_hole_reacher_dmp['name'] = _v
register(
id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
# max_episode_steps=1,
kwargs=kwargs_dict_hole_reacher_dmp
)
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
_env_id = f'{_name[0]}ProMP-{_name[1]}'
kwargs_dict_hole_reacher_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
kwargs_dict_hole_reacher_promp['wrappers'].append(classic_control.hole_reacher.MPWrapper)
kwargs_dict_hole_reacher_promp['trajectory_generator_kwargs']['weight_scale'] = 2
kwargs_dict_hole_reacher_promp['controller_kwargs']['controller_type'] = 'velocity'
kwargs_dict_hole_reacher_promp['name'] = f"{_v}"
register(
id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_hole_reacher_promp
)
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
## ReacherNd
_versions = ["Reacher5d-v0", "Reacher7d-v0", "Reacher5dSparse-v0", "Reacher7dSparse-v0"] _versions = ["Reacher5d-v0", "Reacher7d-v0", "Reacher5dSparse-v0", "Reacher7dSparse-v0"]
for _v in _versions: for _v in _versions:
_name = _v.split("-") _name = _v.split("-")
@ -376,7 +219,7 @@ for _v in _versions:
kwargs_dict_reacher_dmp['wrappers'].append(mujoco.reacher.MPWrapper) kwargs_dict_reacher_dmp['wrappers'].append(mujoco.reacher.MPWrapper)
kwargs_dict_reacher_dmp['phase_generator_kwargs']['alpha_phase'] = 2 kwargs_dict_reacher_dmp['phase_generator_kwargs']['alpha_phase'] = 2
kwargs_dict_reacher_dmp['name'] = _v kwargs_dict_reacher_dmp['name'] = _v
register( gym_register(
id=_env_id, id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
# max_episode_steps=1, # max_episode_steps=1,
@ -388,14 +231,14 @@ for _v in _versions:
kwargs_dict_reacher_promp = deepcopy(DEFAULT_BB_DICT_ProMP) kwargs_dict_reacher_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
kwargs_dict_reacher_promp['wrappers'].append(mujoco.reacher.MPWrapper) kwargs_dict_reacher_promp['wrappers'].append(mujoco.reacher.MPWrapper)
kwargs_dict_reacher_promp['name'] = _v kwargs_dict_reacher_promp['name'] = _v
register( gym_register(
id=_env_id, id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_reacher_promp kwargs=kwargs_dict_reacher_promp
) )
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
######################################################################################################################## ########################################################################################################################
## Beerpong ProMP # Beerpong ProMP
_versions = ['BeerPong-v0'] _versions = ['BeerPong-v0']
for _v in _versions: for _v in _versions:
_name = _v.split("-") _name = _v.split("-")
@ -408,14 +251,14 @@ for _v in _versions:
kwargs_dict_bp_promp['basis_generator_kwargs']['num_basis'] = 2 kwargs_dict_bp_promp['basis_generator_kwargs']['num_basis'] = 2
kwargs_dict_bp_promp['basis_generator_kwargs']['num_basis_zero_start'] = 2 kwargs_dict_bp_promp['basis_generator_kwargs']['num_basis_zero_start'] = 2
kwargs_dict_bp_promp['name'] = _v kwargs_dict_bp_promp['name'] = _v
register( gym_register(
id=_env_id, id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_bp_promp kwargs=kwargs_dict_bp_promp
) )
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
### BP with Fixed release # BP with Fixed release
_versions = ["BeerPongStepBased-v0", 'BeerPong-v0'] _versions = ["BeerPongStepBased-v0", 'BeerPong-v0']
for _v in _versions: for _v in _versions:
if _v != 'BeerPong-v0': if _v != 'BeerPong-v0':
@ -431,7 +274,7 @@ for _v in _versions:
kwargs_dict_bp_promp['basis_generator_kwargs']['num_basis'] = 2 kwargs_dict_bp_promp['basis_generator_kwargs']['num_basis'] = 2
kwargs_dict_bp_promp['basis_generator_kwargs']['num_basis_zero_start'] = 2 kwargs_dict_bp_promp['basis_generator_kwargs']['num_basis_zero_start'] = 2
kwargs_dict_bp_promp['name'] = _v kwargs_dict_bp_promp['name'] = _v
register( gym_register(
id=_env_id, id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_bp_promp kwargs=kwargs_dict_bp_promp
@ -439,7 +282,7 @@ for _v in _versions:
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
######################################################################################################################## ########################################################################################################################
## Table Tennis needs to be fixed according to Zhou's implementation # Table Tennis needs to be fixed according to Zhou's implementation
# TODO: Add later when finished # TODO: Add later when finished
# ######################################################################################################################## # ########################################################################################################################
@ -452,7 +295,7 @@ for _v in _versions:
# kwargs_dict_ant_jump_promp = deepcopy(DEFAULT_BB_DICT_ProMP) # kwargs_dict_ant_jump_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
# kwargs_dict_ant_jump_promp['wrappers'].append(mujoco.ant_jump.MPWrapper) # kwargs_dict_ant_jump_promp['wrappers'].append(mujoco.ant_jump.MPWrapper)
# kwargs_dict_ant_jump_promp['name'] = _v # kwargs_dict_ant_jump_promp['name'] = _v
# register( # gym_register(
# id=_env_id, # id=_env_id,
# entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', # entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
# kwargs=kwargs_dict_ant_jump_promp # kwargs=kwargs_dict_ant_jump_promp
@ -469,7 +312,7 @@ for _v in _versions:
# kwargs_dict_halfcheetah_jump_promp = deepcopy(DEFAULT_BB_DICT_ProMP) # kwargs_dict_halfcheetah_jump_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
# kwargs_dict_halfcheetah_jump_promp['wrappers'].append(mujoco.half_cheetah_jump.MPWrapper) # kwargs_dict_halfcheetah_jump_promp['wrappers'].append(mujoco.half_cheetah_jump.MPWrapper)
# kwargs_dict_halfcheetah_jump_promp['name'] = _v # kwargs_dict_halfcheetah_jump_promp['name'] = _v
# register( # gym_register(
# id=_env_id, # id=_env_id,
# entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', # entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
# kwargs=kwargs_dict_halfcheetah_jump_promp # kwargs=kwargs_dict_halfcheetah_jump_promp
@ -479,7 +322,7 @@ for _v in _versions:
# ######################################################################################################################## # ########################################################################################################################
## HopperJump # HopperJump
_versions = ['HopperJump-v0', 'HopperJumpSparse-v0', _versions = ['HopperJump-v0', 'HopperJumpSparse-v0',
# 'HopperJumpOnBox-v0', 'HopperThrow-v0', 'HopperThrowInBasket-v0' # 'HopperJumpOnBox-v0', 'HopperThrow-v0', 'HopperThrowInBasket-v0'
] ]
@ -490,7 +333,7 @@ for _v in _versions:
kwargs_dict_hopper_jump_promp = deepcopy(DEFAULT_BB_DICT_ProMP) kwargs_dict_hopper_jump_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
kwargs_dict_hopper_jump_promp['wrappers'].append(mujoco.hopper_jump.MPWrapper) kwargs_dict_hopper_jump_promp['wrappers'].append(mujoco.hopper_jump.MPWrapper)
kwargs_dict_hopper_jump_promp['name'] = _v kwargs_dict_hopper_jump_promp['name'] = _v
register( gym_register(
id=_env_id, id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_hopper_jump_promp kwargs=kwargs_dict_hopper_jump_promp
@ -499,7 +342,7 @@ for _v in _versions:
# ######################################################################################################################## # ########################################################################################################################
## Box Pushing # Box Pushing
_versions = ['BoxPushingDense-v0', 'BoxPushingTemporalSparse-v0', 'BoxPushingTemporalSpatialSparse-v0'] _versions = ['BoxPushingDense-v0', 'BoxPushingTemporalSparse-v0', 'BoxPushingTemporalSpatialSparse-v0']
for _v in _versions: for _v in _versions:
_name = _v.split("-") _name = _v.split("-")
@ -511,7 +354,7 @@ for _v in _versions:
kwargs_dict_box_pushing_promp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.]) kwargs_dict_box_pushing_promp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.])
kwargs_dict_box_pushing_promp['basis_generator_kwargs']['basis_bandwidth_factor'] = 2 # 3.5, 4 to try kwargs_dict_box_pushing_promp['basis_generator_kwargs']['basis_bandwidth_factor'] = 2 # 3.5, 4 to try
register( gym_register(
id=_env_id, id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_box_pushing_promp kwargs=kwargs_dict_box_pushing_promp
@ -537,14 +380,14 @@ for _v in _versions:
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_planning_times'] = 4 kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_planning_times'] = 4
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t: t % 25 == 0 kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t: t % 25 == 0
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['condition_on_desired'] = True kwargs_dict_box_pushing_prodmp['black_box_kwargs']['condition_on_desired'] = True
register( gym_register(
id=_env_id, id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_box_pushing_prodmp kwargs=kwargs_dict_box_pushing_prodmp
) )
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id) ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
## Table Tennis # Table Tennis
_versions = ['TableTennis2D-v0', 'TableTennis4D-v0', 'TableTennisWind-v0', 'TableTennisGoalSwitching-v0'] _versions = ['TableTennis2D-v0', 'TableTennis4D-v0', 'TableTennisWind-v0', 'TableTennisGoalSwitching-v0']
for _v in _versions: for _v in _versions:
_name = _v.split("-") _name = _v.split("-")
@ -565,7 +408,7 @@ for _v in _versions:
kwargs_dict_tt_promp['basis_generator_kwargs']['num_basis_zero_start'] = 1 kwargs_dict_tt_promp['basis_generator_kwargs']['num_basis_zero_start'] = 1
kwargs_dict_tt_promp['basis_generator_kwargs']['num_basis_zero_goal'] = 1 kwargs_dict_tt_promp['basis_generator_kwargs']['num_basis_zero_goal'] = 1
kwargs_dict_tt_promp['black_box_kwargs']['verbose'] = 2 kwargs_dict_tt_promp['black_box_kwargs']['verbose'] = 2
register( gym_register(
id=_env_id, id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_tt_promp kwargs=kwargs_dict_tt_promp
@ -595,7 +438,7 @@ for _v in _versions:
kwargs_dict_tt_prodmp['basis_generator_kwargs']['alpha'] = 25. kwargs_dict_tt_prodmp['basis_generator_kwargs']['alpha'] = 25.
kwargs_dict_tt_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3 kwargs_dict_tt_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3
kwargs_dict_tt_prodmp['phase_generator_kwargs']['alpha_phase'] = 3 kwargs_dict_tt_prodmp['phase_generator_kwargs']['alpha_phase'] = 3
register( gym_register(
id=_env_id, id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_tt_prodmp kwargs=kwargs_dict_tt_prodmp
@ -625,7 +468,7 @@ for _v in _versions:
kwargs_dict_tt_prodmp['phase_generator_kwargs']['alpha_phase'] = 3 kwargs_dict_tt_prodmp['phase_generator_kwargs']['alpha_phase'] = 3
kwargs_dict_tt_prodmp['black_box_kwargs']['max_planning_times'] = 3 kwargs_dict_tt_prodmp['black_box_kwargs']['max_planning_times'] = 3
kwargs_dict_tt_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t: t % 50 == 0 kwargs_dict_tt_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t: t % 50 == 0
register( gym_register(
id=_env_id, id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_tt_prodmp kwargs=kwargs_dict_tt_prodmp
@ -640,16 +483,16 @@ for _v in _versions:
# kwargs_dict_walker2d_jump_promp = deepcopy(DEFAULT_BB_DICT_ProMP) # kwargs_dict_walker2d_jump_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
# kwargs_dict_walker2d_jump_promp['wrappers'].append(mujoco.walker_2d_jump.MPWrapper) # kwargs_dict_walker2d_jump_promp['wrappers'].append(mujoco.walker_2d_jump.MPWrapper)
# kwargs_dict_walker2d_jump_promp['name'] = _v # kwargs_dict_walker2d_jump_promp['name'] = _v
# register( # gym_register(
# id=_env_id, # id=_env_id,
# entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', # entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
# kwargs=kwargs_dict_walker2d_jump_promp # kwargs=kwargs_dict_walker2d_jump_promp
# ) # )
# ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) # ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
### Depricated, we will not provide non random starts anymore # Depricated, we will not provide non random starts anymore
""" """
register( gym_register(
id='SimpleReacher-v1', id='SimpleReacher-v1',
entry_point='fancy_gym.envs.classic_control:SimpleReacherEnv', entry_point='fancy_gym.envs.classic_control:SimpleReacherEnv',
max_episode_steps=200, max_episode_steps=200,
@ -659,7 +502,7 @@ register(
} }
) )
register( gym_register(
id='LongSimpleReacher-v1', id='LongSimpleReacher-v1',
entry_point='fancy_gym.envs.classic_control:SimpleReacherEnv', entry_point='fancy_gym.envs.classic_control:SimpleReacherEnv',
max_episode_steps=200, max_episode_steps=200,
@ -668,7 +511,7 @@ register(
"random_start": False "random_start": False
} }
) )
register( gym_register(
id='HoleReacher-v1', id='HoleReacher-v1',
entry_point='fancy_gym.envs.classic_control:HoleReacherEnv', entry_point='fancy_gym.envs.classic_control:HoleReacherEnv',
max_episode_steps=200, max_episode_steps=200,
@ -683,7 +526,7 @@ register(
"collision_penalty": 100, "collision_penalty": 100,
} }
) )
register( gym_register(
id='HoleReacher-v2', id='HoleReacher-v2',
entry_point='fancy_gym.envs.classic_control:HoleReacherEnv', entry_point='fancy_gym.envs.classic_control:HoleReacherEnv',
max_episode_steps=200, max_episode_steps=200,
@ -700,7 +543,7 @@ register(
) )
# CtxtFree are v0, Contextual are v1 # CtxtFree are v0, Contextual are v1
register( gym_register(
id='AntJump-v0', id='AntJump-v0',
entry_point='fancy_gym.envs.mujoco:AntJumpEnv', entry_point='fancy_gym.envs.mujoco:AntJumpEnv',
max_episode_steps=MAX_EPISODE_STEPS_ANTJUMP, max_episode_steps=MAX_EPISODE_STEPS_ANTJUMP,
@ -710,7 +553,7 @@ register(
} }
) )
# CtxtFree are v0, Contextual are v1 # CtxtFree are v0, Contextual are v1
register( gym_register(
id='HalfCheetahJump-v0', id='HalfCheetahJump-v0',
entry_point='fancy_gym.envs.mujoco:HalfCheetahJumpEnv', entry_point='fancy_gym.envs.mujoco:HalfCheetahJumpEnv',
max_episode_steps=MAX_EPISODE_STEPS_HALFCHEETAHJUMP, max_episode_steps=MAX_EPISODE_STEPS_HALFCHEETAHJUMP,
@ -719,7 +562,7 @@ register(
"context": False "context": False
} }
) )
register( gym_register(
id='HopperJump-v0', id='HopperJump-v0',
entry_point='fancy_gym.envs.mujoco:HopperJumpEnv', entry_point='fancy_gym.envs.mujoco:HopperJumpEnv',
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP, max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP,
@ -732,12 +575,12 @@ register(
""" """
### Deprecated used for CorL paper # Deprecated used for CorL paper
""" """
_vs = np.arange(101).tolist() + [1e-5, 5e-5, 1e-4, 5e-4, 1e-3, 5e-3, 1e-2, 5e-2, 1e-1, 5e-1] _vs = np.arange(101).tolist() + [1e-5, 5e-5, 1e-4, 5e-4, 1e-3, 5e-3, 1e-2, 5e-2, 1e-1, 5e-1]
for i in _vs: for i in _vs:
_env_id = f'ALRReacher{i}-v0' _env_id = f'ALRReacher{i}-v0'
register( gym_register(
id=_env_id, id=_env_id,
entry_point='fancy_gym.envs.mujoco:ReacherEnv', entry_point='fancy_gym.envs.mujoco:ReacherEnv',
max_episode_steps=200, max_episode_steps=200,
@ -750,7 +593,7 @@ for i in _vs:
) )
_env_id = f'ALRReacherSparse{i}-v0' _env_id = f'ALRReacherSparse{i}-v0'
register( gym_register(
id=_env_id, id=_env_id,
entry_point='fancy_gym.envs.mujoco:ReacherEnv', entry_point='fancy_gym.envs.mujoco:ReacherEnv',
max_episode_steps=200, max_episode_steps=200,
@ -764,7 +607,7 @@ for i in _vs:
_vs = np.arange(101).tolist() + [1e-5, 5e-5, 1e-4, 5e-4, 1e-3, 5e-3, 1e-2, 5e-2, 1e-1, 5e-1] _vs = np.arange(101).tolist() + [1e-5, 5e-5, 1e-4, 5e-4, 1e-3, 5e-3, 1e-2, 5e-2, 1e-1, 5e-1]
for i in _vs: for i in _vs:
_env_id = f'ALRReacher{i}ProMP-v0' _env_id = f'ALRReacher{i}ProMP-v0'
register( gym_register(
id=_env_id, id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_promp_env_helper', entry_point='fancy_gym.utils.make_env_helpers:make_promp_env_helper',
kwargs={ kwargs={
@ -787,7 +630,7 @@ for i in _vs:
) )
_env_id = f'ALRReacherSparse{i}ProMP-v0' _env_id = f'ALRReacherSparse{i}ProMP-v0'
register( gym_register(
id=_env_id, id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_promp_env_helper', entry_point='fancy_gym.utils.make_env_helpers:make_promp_env_helper',
kwargs={ kwargs={
@ -809,7 +652,7 @@ for i in _vs:
} }
) )
register( gym_register(
id='HopperJumpOnBox-v0', id='HopperJumpOnBox-v0',
entry_point='fancy_gym.envs.mujoco:HopperJumpOnBoxEnv', entry_point='fancy_gym.envs.mujoco:HopperJumpOnBoxEnv',
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMPONBOX, max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMPONBOX,
@ -818,7 +661,7 @@ for i in _vs:
"context": False "context": False
} }
) )
register( gym_register(
id='HopperThrow-v0', id='HopperThrow-v0',
entry_point='fancy_gym.envs.mujoco:HopperThrowEnv', entry_point='fancy_gym.envs.mujoco:HopperThrowEnv',
max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROW, max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROW,
@ -827,7 +670,7 @@ for i in _vs:
"context": False "context": False
} }
) )
register( gym_register(
id='HopperThrowInBasket-v0', id='HopperThrowInBasket-v0',
entry_point='fancy_gym.envs.mujoco:HopperThrowInBasketEnv', entry_point='fancy_gym.envs.mujoco:HopperThrowInBasketEnv',
max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROWINBASKET, max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROWINBASKET,
@ -836,7 +679,7 @@ for i in _vs:
"context": False "context": False
} }
) )
register( gym_register(
id='Walker2DJump-v0', id='Walker2DJump-v0',
entry_point='fancy_gym.envs.mujoco:Walker2dJumpEnv', entry_point='fancy_gym.envs.mujoco:Walker2dJumpEnv',
max_episode_steps=MAX_EPISODE_STEPS_WALKERJUMP, max_episode_steps=MAX_EPISODE_STEPS_WALKERJUMP,
@ -845,12 +688,12 @@ for i in _vs:
"context": False "context": False
} }
) )
register(id='TableTennis2DCtxt-v1', gym_register(id='TableTennis2DCtxt-v1',
entry_point='fancy_gym.envs.mujoco:TTEnvGym', entry_point='fancy_gym.envs.mujoco:TTEnvGym',
max_episode_steps=MAX_EPISODE_STEPS, max_episode_steps=MAX_EPISODE_STEPS,
kwargs={'ctxt_dim': 2, 'fixed_goal': True}) kwargs={'ctxt_dim': 2, 'fixed_goal': True})
register( gym_register(
id='BeerPong-v0', id='BeerPong-v0',
entry_point='fancy_gym.envs.mujoco:BeerBongEnv', entry_point='fancy_gym.envs.mujoco:BeerBongEnv',
max_episode_steps=300, max_episode_steps=300,

View File

@ -8,11 +8,41 @@ from gymnasium.core import ObsType
from matplotlib import patches from matplotlib import patches
from fancy_gym.envs.classic_control.base_reacher.base_reacher_direct import BaseReacherDirectEnv from fancy_gym.envs.classic_control.base_reacher.base_reacher_direct import BaseReacherDirectEnv
from . import MPWrapper
MAX_EPISODE_STEPS_HOLEREACHER = 200 MAX_EPISODE_STEPS_HOLEREACHER = 200
class HoleReacherEnv(BaseReacherDirectEnv): class HoleReacherEnv(BaseReacherDirectEnv):
metadata = {
'mp_config': {
'ProMP': {
'wrappers': [MPWrapper],
'controller_kwargs': {
'controller_type': 'velocity',
},
'trajectory_generator_kwargs': {
'weight_scale': 2,
},
},
'DMP': {
'wrappers': [MPWrapper],
'controller_kwargs': {
'controller_type': 'velocity',
},
'trajectory_generator_kwargs': {
# TODO: Before it was weight scale 50 and goal scale 0.1. We now only have weight scale and thus set it to 500. Check
'weight_scale': 500,
},
'phase_generator_kwargs': {
'alpha_phase': 2.5,
},
},
'ProDMP': {},
}
}
def __init__(self, n_links: int, hole_x: Union[None, float] = None, hole_depth: Union[None, float] = None, def __init__(self, n_links: int, hole_x: Union[None, float] = None, hole_depth: Union[None, float] = None,
hole_width: float = 1., random_start: bool = False, allow_self_collision: bool = False, hole_width: float = 1., random_start: bool = False, allow_self_collision: bool = False,
allow_wall_collision: bool = False, collision_penalty: float = 1000, rew_fct: str = "simple"): allow_wall_collision: bool = False, collision_penalty: float = 1000, rew_fct: str = "simple"):

View File

@ -6,6 +6,7 @@ from gymnasium import spaces
from gymnasium.core import ObsType from gymnasium.core import ObsType
from fancy_gym.envs.classic_control.base_reacher.base_reacher_torque import BaseReacherTorqueEnv from fancy_gym.envs.classic_control.base_reacher.base_reacher_torque import BaseReacherTorqueEnv
from . import MPWrapper
class SimpleReacherEnv(BaseReacherTorqueEnv): class SimpleReacherEnv(BaseReacherTorqueEnv):
@ -15,6 +16,32 @@ class SimpleReacherEnv(BaseReacherTorqueEnv):
towards the end of the trajectory. towards the end of the trajectory.
""" """
metadata = {
'mp_config': {
'ProMP': {
'wrappers': [MPWrapper],
'controller_kwargs': {
'p_gains': 0.6,
'd_gains': 0.075,
},
},
'DMP': {
'wrappers': [MPWrapper],
'controller_kwargs': {
'p_gains': 0.6,
'd_gains': 0.075,
},
'trajectory_generator_kwargs': {
'weight_scale': 50,
},
'phase_generator_kwargs': {
'alpha_phase': 2,
},
},
'ProDMP': {},
}
}
def __init__(self, n_links: int, target: Union[None, Iterable] = None, random_start: bool = True, def __init__(self, n_links: int, target: Union[None, Iterable] = None, random_start: bool = True,
allow_self_collision: bool = False, ): allow_self_collision: bool = False, ):
super().__init__(n_links, random_start, allow_self_collision) super().__init__(n_links, random_start, allow_self_collision)
@ -126,4 +153,3 @@ class SimpleReacherEnv(BaseReacherTorqueEnv):
self.fig.canvas.draw() self.fig.canvas.draw()
self.fig.canvas.flush_events() self.fig.canvas.flush_events()

View File

@ -7,10 +7,35 @@ from gymnasium import spaces
from gymnasium.core import ObsType from gymnasium.core import ObsType
from fancy_gym.envs.classic_control.base_reacher.base_reacher_direct import BaseReacherDirectEnv from fancy_gym.envs.classic_control.base_reacher.base_reacher_direct import BaseReacherDirectEnv
from . import MPWrapper
class ViaPointReacherEnv(BaseReacherDirectEnv): class ViaPointReacherEnv(BaseReacherDirectEnv):
metadata = {
'mp_config': {
'ProMP': {
'wrappers': [MPWrapper],
'controller_kwargs': {
'controller_type': 'velocity',
},
},
'DMP': {
'wrappers': [MPWrapper],
'controller_kwargs': {
'controller_type': 'velocity',
},
'trajectory_generator_kwargs': {
'weight_scale': 50,
},
'phase_generator_kwargs': {
'alpha_phase': 2,
},
},
'ProDMP': {},
}
}
def __init__(self, n_links, random_start: bool = False, via_target: Union[None, Iterable] = None, def __init__(self, n_links, random_start: bool = False, via_target: Union[None, Iterable] = None,
target: Union[None, Iterable] = None, allow_self_collision=False, collision_penalty=1000): target: Union[None, Iterable] = None, allow_self_collision=False, collision_penalty=1000):
@ -184,5 +209,3 @@ class ViaPointReacherEnv(BaseReacherDirectEnv):
plt.plot(self._joints[:, 0], self._joints[:, 1], 'ro-', markerfacecolor='k') plt.plot(self._joints[:, 0], self._joints[:, 1], 'ro-', markerfacecolor='k')
plt.pause(0.01) plt.pause(0.01)