refactoring env registration wip
This commit is contained in:
parent
1b061b2a37
commit
dbf2be1006
@ -5,9 +5,14 @@ from gymnasium import register as gym_register
|
|||||||
from .registry import register, ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS
|
from .registry import register, ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS
|
||||||
|
|
||||||
from . import classic_control, mujoco
|
from . import classic_control, mujoco
|
||||||
from .classic_control.hole_reacher.hole_reacher import HoleReacherEnv
|
|
||||||
from .classic_control.simple_reacher.simple_reacher import SimpleReacherEnv
|
from .classic_control.simple_reacher.simple_reacher import SimpleReacherEnv
|
||||||
|
from .classic_control.simple_reacher import MPWrapper as MPWrapper_SimpleReacher
|
||||||
|
from .classic_control.hole_reacher.hole_reacher import HoleReacherEnv
|
||||||
|
from .classic_control.hole_reacher import MPWrapper as MPWrapper_HoleReacher
|
||||||
from .classic_control.viapoint_reacher.viapoint_reacher import ViaPointReacherEnv
|
from .classic_control.viapoint_reacher.viapoint_reacher import ViaPointReacherEnv
|
||||||
|
from .classic_control.viapoint_reacher import MPWrapper as MPWrapper_ViaPointReacher
|
||||||
|
from .mujoco.reacher.reacher import ReacherEnv, MAX_EPISODE_STEPS_REACHER
|
||||||
|
from .mujoco.reacher.mp_wrapper import MPWrapper as MPWrapper_Reacher
|
||||||
from .mujoco.ant_jump.ant_jump import MAX_EPISODE_STEPS_ANTJUMP
|
from .mujoco.ant_jump.ant_jump import MAX_EPISODE_STEPS_ANTJUMP
|
||||||
from .mujoco.beerpong.beerpong import MAX_EPISODE_STEPS_BEERPONG, FIXED_RELEASE_STEP
|
from .mujoco.beerpong.beerpong import MAX_EPISODE_STEPS_BEERPONG, FIXED_RELEASE_STEP
|
||||||
from .mujoco.half_cheetah_jump.half_cheetah_jump import MAX_EPISODE_STEPS_HALFCHEETAHJUMP
|
from .mujoco.half_cheetah_jump.half_cheetah_jump import MAX_EPISODE_STEPS_HALFCHEETAHJUMP
|
||||||
@ -15,7 +20,6 @@ from .mujoco.hopper_jump.hopper_jump import MAX_EPISODE_STEPS_HOPPERJUMP
|
|||||||
from .mujoco.hopper_jump.hopper_jump_on_box import MAX_EPISODE_STEPS_HOPPERJUMPONBOX
|
from .mujoco.hopper_jump.hopper_jump_on_box import MAX_EPISODE_STEPS_HOPPERJUMPONBOX
|
||||||
from .mujoco.hopper_throw.hopper_throw import MAX_EPISODE_STEPS_HOPPERTHROW
|
from .mujoco.hopper_throw.hopper_throw import MAX_EPISODE_STEPS_HOPPERTHROW
|
||||||
from .mujoco.hopper_throw.hopper_throw_in_basket import MAX_EPISODE_STEPS_HOPPERTHROWINBASKET
|
from .mujoco.hopper_throw.hopper_throw_in_basket import MAX_EPISODE_STEPS_HOPPERTHROWINBASKET
|
||||||
from .mujoco.reacher.reacher import ReacherEnv, MAX_EPISODE_STEPS_REACHER
|
|
||||||
from .mujoco.walker_2d_jump.walker_2d_jump import MAX_EPISODE_STEPS_WALKERJUMP
|
from .mujoco.walker_2d_jump.walker_2d_jump import MAX_EPISODE_STEPS_WALKERJUMP
|
||||||
from .mujoco.box_pushing.box_pushing_env import BoxPushingDense, BoxPushingTemporalSparse, \
|
from .mujoco.box_pushing.box_pushing_env import BoxPushingDense, BoxPushingTemporalSparse, \
|
||||||
BoxPushingTemporalSpatialSparse, MAX_EPISODE_STEPS_BOX_PUSHING
|
BoxPushingTemporalSpatialSparse, MAX_EPISODE_STEPS_BOX_PUSHING
|
||||||
@ -26,7 +30,8 @@ from .mujoco.table_tennis.table_tennis_env import TableTennisEnv, TableTennisWin
|
|||||||
# Simple Reacher
|
# Simple Reacher
|
||||||
register(
|
register(
|
||||||
id='SimpleReacher-v0',
|
id='SimpleReacher-v0',
|
||||||
entry_point='fancy_gym.envs.classic_control:SimpleReacherEnv',
|
entry_point=SimpleReacherEnv,
|
||||||
|
mp_wrapper=MPWrapper_SimpleReacher,
|
||||||
max_episode_steps=200,
|
max_episode_steps=200,
|
||||||
kwargs={
|
kwargs={
|
||||||
"n_links": 2,
|
"n_links": 2,
|
||||||
@ -35,7 +40,8 @@ register(
|
|||||||
|
|
||||||
register(
|
register(
|
||||||
id='LongSimpleReacher-v0',
|
id='LongSimpleReacher-v0',
|
||||||
entry_point='fancy_gym.envs.classic_control:SimpleReacherEnv',
|
entry_point=SimpleReacherEnv,
|
||||||
|
mp_wrapper=MPWrapper_SimpleReacher,
|
||||||
max_episode_steps=200,
|
max_episode_steps=200,
|
||||||
kwargs={
|
kwargs={
|
||||||
"n_links": 5,
|
"n_links": 5,
|
||||||
@ -45,7 +51,8 @@ register(
|
|||||||
# Viapoint Reacher
|
# Viapoint Reacher
|
||||||
register(
|
register(
|
||||||
id='ViaPointReacher-v0',
|
id='ViaPointReacher-v0',
|
||||||
entry_point='fancy_gym.envs.classic_control:ViaPointReacherEnv',
|
entry_point=ViaPointReacherEnv,
|
||||||
|
mp_wrapper=MPWrapper_ViaPointReacher,
|
||||||
max_episode_steps=200,
|
max_episode_steps=200,
|
||||||
kwargs={
|
kwargs={
|
||||||
"n_links": 5,
|
"n_links": 5,
|
||||||
@ -57,7 +64,8 @@ register(
|
|||||||
# Hole Reacher
|
# Hole Reacher
|
||||||
register(
|
register(
|
||||||
id='HoleReacher-v0',
|
id='HoleReacher-v0',
|
||||||
entry_point='fancy_gym.envs.classic_control:HoleReacherEnv',
|
entry_point=HoleReacherEnv,
|
||||||
|
mp_wrapper=MPWrapper_HoleReacher,
|
||||||
max_episode_steps=200,
|
max_episode_steps=200,
|
||||||
kwargs={
|
kwargs={
|
||||||
"n_links": 5,
|
"n_links": 5,
|
||||||
@ -74,39 +82,44 @@ register(
|
|||||||
# Mujoco
|
# Mujoco
|
||||||
|
|
||||||
# Mujoco Reacher
|
# Mujoco Reacher
|
||||||
for _dims in [5, 7]:
|
for dims in [5, 7]:
|
||||||
gym_register(
|
register(
|
||||||
id=f'Reacher{_dims}d-v0',
|
id=f'Reacher{dims}d-v0',
|
||||||
entry_point='fancy_gym.envs.mujoco:ReacherEnv',
|
entry_point=ReacherEnv,
|
||||||
|
mp_wrapper=MPWrapper_Reacher,
|
||||||
max_episode_steps=MAX_EPISODE_STEPS_REACHER,
|
max_episode_steps=MAX_EPISODE_STEPS_REACHER,
|
||||||
kwargs={
|
kwargs={
|
||||||
"n_links": _dims,
|
"n_links": dims,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
gym_register(
|
register(
|
||||||
id=f'Reacher{_dims}dSparse-v0',
|
id=f'Reacher{dims}dSparse-v0',
|
||||||
entry_point='fancy_gym.envs.mujoco:ReacherEnv',
|
entry_point=ReacherEnv,
|
||||||
|
mp_wrapper=MPWrapper_Reacher,
|
||||||
max_episode_steps=MAX_EPISODE_STEPS_REACHER,
|
max_episode_steps=MAX_EPISODE_STEPS_REACHER,
|
||||||
kwargs={
|
kwargs={
|
||||||
"sparse": True,
|
"sparse": True,
|
||||||
'reward_weight': 200,
|
'reward_weight': 200,
|
||||||
"n_links": _dims,
|
"n_links": dims,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
gym_register(
|
|
||||||
|
register(
|
||||||
id='HopperJumpSparse-v0',
|
id='HopperJumpSparse-v0',
|
||||||
entry_point='fancy_gym.envs.mujoco:HopperJumpEnv',
|
entry_point='fancy_gym.envs.mujoco:HopperJumpEnv',
|
||||||
|
mp_wrapper=mujoco.hopper_jump.MPWrapper,
|
||||||
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP,
|
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP,
|
||||||
kwargs={
|
kwargs={
|
||||||
"sparse": True,
|
"sparse": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
gym_register(
|
register(
|
||||||
id='HopperJump-v0',
|
id='HopperJump-v0',
|
||||||
entry_point='fancy_gym.envs.mujoco:HopperJumpEnv',
|
entry_point='fancy_gym.envs.mujoco:HopperJumpEnv',
|
||||||
|
mp_wrapper=mujoco.hopper_jump.MPWrapper,
|
||||||
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP,
|
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP,
|
||||||
kwargs={
|
kwargs={
|
||||||
"sparse": False,
|
"sparse": False,
|
||||||
@ -160,9 +173,18 @@ gym_register(
|
|||||||
|
|
||||||
# Box pushing environments with different rewards
|
# Box pushing environments with different rewards
|
||||||
for reward_type in ["Dense", "TemporalSparse", "TemporalSpatialSparse"]:
|
for reward_type in ["Dense", "TemporalSparse", "TemporalSpatialSparse"]:
|
||||||
gym_register(
|
register(
|
||||||
id='BoxPushing{}-v0'.format(reward_type),
|
id='BoxPushing{}-v0'.format(reward_type),
|
||||||
entry_point='fancy_gym.envs.mujoco:BoxPushing{}'.format(reward_type),
|
entry_point='fancy_gym.envs.mujoco:BoxPushing{}'.format(reward_type),
|
||||||
|
mp_wrapper=mujoco.box_pushing.MPWrapper,
|
||||||
|
max_episode_steps=MAX_EPISODE_STEPS_BOX_PUSHING,
|
||||||
|
)
|
||||||
|
|
||||||
|
register(
|
||||||
|
id='BoxPushing{}Replan-v0'.format(reward_type),
|
||||||
|
entry_point='fancy_gym.envs.mujoco:BoxPushing{}'.format(reward_type),
|
||||||
|
mp_wrapper=mujoco.box_pushing.ReplanMPWrapper,
|
||||||
|
register_step_based=False,
|
||||||
max_episode_steps=MAX_EPISODE_STEPS_BOX_PUSHING,
|
max_episode_steps=MAX_EPISODE_STEPS_BOX_PUSHING,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -202,42 +224,6 @@ gym_register(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# movement Primitive Environments
|
|
||||||
|
|
||||||
# Simple Reacher [DONE]
|
|
||||||
|
|
||||||
# Viapoint reacher [DONE]
|
|
||||||
|
|
||||||
# Hole Reacher [DONE]
|
|
||||||
|
|
||||||
# ReacherNd
|
|
||||||
_versions = ["Reacher5d-v0", "Reacher7d-v0", "Reacher5dSparse-v0", "Reacher7dSparse-v0"]
|
|
||||||
for _v in _versions:
|
|
||||||
_name = _v.split("-")
|
|
||||||
_env_id = f'{_name[0]}DMP-{_name[1]}'
|
|
||||||
kwargs_dict_reacher_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
|
|
||||||
kwargs_dict_reacher_dmp['wrappers'].append(mujoco.reacher.MPWrapper)
|
|
||||||
kwargs_dict_reacher_dmp['phase_generator_kwargs']['alpha_phase'] = 2
|
|
||||||
kwargs_dict_reacher_dmp['name'] = _v
|
|
||||||
gym_register(
|
|
||||||
id=_env_id,
|
|
||||||
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
|
||||||
# max_episode_steps=1,
|
|
||||||
kwargs=kwargs_dict_reacher_dmp
|
|
||||||
)
|
|
||||||
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
|
|
||||||
|
|
||||||
_env_id = f'{_name[0]}ProMP-{_name[1]}'
|
|
||||||
kwargs_dict_reacher_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
|
||||||
kwargs_dict_reacher_promp['wrappers'].append(mujoco.reacher.MPWrapper)
|
|
||||||
kwargs_dict_reacher_promp['name'] = _v
|
|
||||||
gym_register(
|
|
||||||
id=_env_id,
|
|
||||||
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
|
||||||
kwargs=kwargs_dict_reacher_promp
|
|
||||||
)
|
|
||||||
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
|
||||||
########################################################################################################################
|
|
||||||
# Beerpong ProMP
|
# Beerpong ProMP
|
||||||
_versions = ['BeerPong-v0']
|
_versions = ['BeerPong-v0']
|
||||||
for _v in _versions:
|
for _v in _versions:
|
||||||
@ -321,72 +307,6 @@ for _v in _versions:
|
|||||||
#
|
#
|
||||||
# ########################################################################################################################
|
# ########################################################################################################################
|
||||||
|
|
||||||
|
|
||||||
# HopperJump
|
|
||||||
_versions = ['HopperJump-v0', 'HopperJumpSparse-v0',
|
|
||||||
# 'HopperJumpOnBox-v0', 'HopperThrow-v0', 'HopperThrowInBasket-v0'
|
|
||||||
]
|
|
||||||
# TODO: Check if all environments work with the same MPWrapper
|
|
||||||
for _v in _versions:
|
|
||||||
_name = _v.split("-")
|
|
||||||
_env_id = f'{_name[0]}ProMP-{_name[1]}'
|
|
||||||
kwargs_dict_hopper_jump_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
|
||||||
kwargs_dict_hopper_jump_promp['wrappers'].append(mujoco.hopper_jump.MPWrapper)
|
|
||||||
kwargs_dict_hopper_jump_promp['name'] = _v
|
|
||||||
gym_register(
|
|
||||||
id=_env_id,
|
|
||||||
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
|
||||||
kwargs=kwargs_dict_hopper_jump_promp
|
|
||||||
)
|
|
||||||
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
|
||||||
|
|
||||||
# ########################################################################################################################
|
|
||||||
|
|
||||||
# Box Pushing
|
|
||||||
_versions = ['BoxPushingDense-v0', 'BoxPushingTemporalSparse-v0', 'BoxPushingTemporalSpatialSparse-v0']
|
|
||||||
for _v in _versions:
|
|
||||||
_name = _v.split("-")
|
|
||||||
_env_id = f'{_name[0]}ProMP-{_name[1]}'
|
|
||||||
kwargs_dict_box_pushing_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
|
||||||
kwargs_dict_box_pushing_promp['wrappers'].append(mujoco.box_pushing.MPWrapper)
|
|
||||||
kwargs_dict_box_pushing_promp['name'] = _v
|
|
||||||
kwargs_dict_box_pushing_promp['controller_kwargs']['p_gains'] = 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.])
|
|
||||||
kwargs_dict_box_pushing_promp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.])
|
|
||||||
kwargs_dict_box_pushing_promp['basis_generator_kwargs']['basis_bandwidth_factor'] = 2 # 3.5, 4 to try
|
|
||||||
|
|
||||||
gym_register(
|
|
||||||
id=_env_id,
|
|
||||||
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
|
||||||
kwargs=kwargs_dict_box_pushing_promp
|
|
||||||
)
|
|
||||||
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
|
||||||
|
|
||||||
for _v in _versions:
|
|
||||||
_name = _v.split("-")
|
|
||||||
_env_id = f'{_name[0]}ReplanProDMP-{_name[1]}'
|
|
||||||
kwargs_dict_box_pushing_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
|
|
||||||
kwargs_dict_box_pushing_prodmp['wrappers'].append(mujoco.box_pushing.MPWrapper)
|
|
||||||
kwargs_dict_box_pushing_prodmp['name'] = _v
|
|
||||||
kwargs_dict_box_pushing_prodmp['controller_kwargs']['p_gains'] = 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.])
|
|
||||||
kwargs_dict_box_pushing_prodmp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.])
|
|
||||||
kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = 0.3
|
|
||||||
kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 0.3
|
|
||||||
kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['auto_scale_basis'] = True
|
|
||||||
kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_offset'] = 1.0
|
|
||||||
kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['disable_goal'] = True
|
|
||||||
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['num_basis'] = 5
|
|
||||||
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3
|
|
||||||
kwargs_dict_box_pushing_prodmp['phase_generator_kwargs']['alpha_phase'] = 3
|
|
||||||
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_planning_times'] = 4
|
|
||||||
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t: t % 25 == 0
|
|
||||||
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['condition_on_desired'] = True
|
|
||||||
gym_register(
|
|
||||||
id=_env_id,
|
|
||||||
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
|
||||||
kwargs=kwargs_dict_box_pushing_prodmp
|
|
||||||
)
|
|
||||||
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
|
|
||||||
|
|
||||||
# Table Tennis
|
# Table Tennis
|
||||||
_versions = ['TableTennis2D-v0', 'TableTennis4D-v0', 'TableTennisWind-v0', 'TableTennisGoalSwitching-v0']
|
_versions = ['TableTennis2D-v0', 'TableTennis4D-v0', 'TableTennisWind-v0', 'TableTennisGoalSwitching-v0']
|
||||||
for _v in _versions:
|
for _v in _versions:
|
||||||
|
Loading…
Reference in New Issue
Block a user