From dbf2be1006015e72f167b1114713ed7a4a3d4265 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Thu, 20 Jul 2023 11:45:53 +0200 Subject: [PATCH] refactoring env registration wip --- fancy_gym/envs/__init__.py | 160 ++++++++++--------------------------- 1 file changed, 40 insertions(+), 120 deletions(-) diff --git a/fancy_gym/envs/__init__.py b/fancy_gym/envs/__init__.py index 2de5d10..62fe5b7 100644 --- a/fancy_gym/envs/__init__.py +++ b/fancy_gym/envs/__init__.py @@ -5,9 +5,14 @@ from gymnasium import register as gym_register from .registry import register, ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS from . import classic_control, mujoco -from .classic_control.hole_reacher.hole_reacher import HoleReacherEnv from .classic_control.simple_reacher.simple_reacher import SimpleReacherEnv +from .classic_control.simple_reacher import MPWrapper as MPWrapper_SimpleReacher +from .classic_control.hole_reacher.hole_reacher import HoleReacherEnv +from .classic_control.hole_reacher import MPWrapper as MPWrapper_HoleReacher from .classic_control.viapoint_reacher.viapoint_reacher import ViaPointReacherEnv +from .classic_control.viapoint_reacher import MPWrapper as MPWrapper_ViaPointReacher +from .mujoco.reacher.reacher import ReacherEnv, MAX_EPISODE_STEPS_REACHER +from .mujoco.reacher.mp_wrapper import MPWrapper as MPWrapper_Reacher from .mujoco.ant_jump.ant_jump import MAX_EPISODE_STEPS_ANTJUMP from .mujoco.beerpong.beerpong import MAX_EPISODE_STEPS_BEERPONG, FIXED_RELEASE_STEP from .mujoco.half_cheetah_jump.half_cheetah_jump import MAX_EPISODE_STEPS_HALFCHEETAHJUMP @@ -15,7 +20,6 @@ from .mujoco.hopper_jump.hopper_jump import MAX_EPISODE_STEPS_HOPPERJUMP from .mujoco.hopper_jump.hopper_jump_on_box import MAX_EPISODE_STEPS_HOPPERJUMPONBOX from .mujoco.hopper_throw.hopper_throw import MAX_EPISODE_STEPS_HOPPERTHROW from .mujoco.hopper_throw.hopper_throw_in_basket import MAX_EPISODE_STEPS_HOPPERTHROWINBASKET -from .mujoco.reacher.reacher import ReacherEnv, MAX_EPISODE_STEPS_REACHER from .mujoco.walker_2d_jump.walker_2d_jump import MAX_EPISODE_STEPS_WALKERJUMP from .mujoco.box_pushing.box_pushing_env import BoxPushingDense, BoxPushingTemporalSparse, \ BoxPushingTemporalSpatialSparse, MAX_EPISODE_STEPS_BOX_PUSHING @@ -26,7 +30,8 @@ from .mujoco.table_tennis.table_tennis_env import TableTennisEnv, TableTennisWin # Simple Reacher register( id='SimpleReacher-v0', - entry_point='fancy_gym.envs.classic_control:SimpleReacherEnv', + entry_point=SimpleReacherEnv, + mp_wrapper=MPWrapper_SimpleReacher, max_episode_steps=200, kwargs={ "n_links": 2, @@ -35,7 +40,8 @@ register( register( id='LongSimpleReacher-v0', - entry_point='fancy_gym.envs.classic_control:SimpleReacherEnv', + entry_point=SimpleReacherEnv, + mp_wrapper=MPWrapper_SimpleReacher, max_episode_steps=200, kwargs={ "n_links": 5, @@ -45,7 +51,8 @@ register( # Viapoint Reacher register( id='ViaPointReacher-v0', - entry_point='fancy_gym.envs.classic_control:ViaPointReacherEnv', + entry_point=ViaPointReacherEnv, + mp_wrapper=MPWrapper_ViaPointReacher, max_episode_steps=200, kwargs={ "n_links": 5, @@ -57,7 +64,8 @@ register( # Hole Reacher register( id='HoleReacher-v0', - entry_point='fancy_gym.envs.classic_control:HoleReacherEnv', + entry_point=HoleReacherEnv, + mp_wrapper=MPWrapper_HoleReacher, max_episode_steps=200, kwargs={ "n_links": 5, @@ -74,39 +82,44 @@ register( # Mujoco # Mujoco Reacher -for _dims in [5, 7]: - gym_register( - id=f'Reacher{_dims}d-v0', - entry_point='fancy_gym.envs.mujoco:ReacherEnv', +for dims in [5, 7]: + register( + id=f'Reacher{dims}d-v0', + entry_point=ReacherEnv, + mp_wrapper=MPWrapper_Reacher, max_episode_steps=MAX_EPISODE_STEPS_REACHER, kwargs={ - "n_links": _dims, + "n_links": dims, } ) - gym_register( - id=f'Reacher{_dims}dSparse-v0', - entry_point='fancy_gym.envs.mujoco:ReacherEnv', + register( + id=f'Reacher{dims}dSparse-v0', + entry_point=ReacherEnv, + mp_wrapper=MPWrapper_Reacher, max_episode_steps=MAX_EPISODE_STEPS_REACHER, kwargs={ "sparse": True, 'reward_weight': 200, - "n_links": _dims, + "n_links": dims, } ) -gym_register( + +register( id='HopperJumpSparse-v0', entry_point='fancy_gym.envs.mujoco:HopperJumpEnv', + mp_wrapper=mujoco.hopper_jump.MPWrapper, max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP, kwargs={ "sparse": True, } ) -gym_register( +register( id='HopperJump-v0', entry_point='fancy_gym.envs.mujoco:HopperJumpEnv', + mp_wrapper=mujoco.hopper_jump.MPWrapper, max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP, kwargs={ "sparse": False, @@ -160,9 +173,18 @@ gym_register( # Box pushing environments with different rewards for reward_type in ["Dense", "TemporalSparse", "TemporalSpatialSparse"]: - gym_register( + register( id='BoxPushing{}-v0'.format(reward_type), entry_point='fancy_gym.envs.mujoco:BoxPushing{}'.format(reward_type), + mp_wrapper=mujoco.box_pushing.MPWrapper, + max_episode_steps=MAX_EPISODE_STEPS_BOX_PUSHING, + ) + + register( + id='BoxPushing{}Replan-v0'.format(reward_type), + entry_point='fancy_gym.envs.mujoco:BoxPushing{}'.format(reward_type), + mp_wrapper=mujoco.box_pushing.ReplanMPWrapper, + register_step_based=False, max_episode_steps=MAX_EPISODE_STEPS_BOX_PUSHING, ) @@ -202,42 +224,6 @@ gym_register( ) -# movement Primitive Environments - -# Simple Reacher [DONE] - -# Viapoint reacher [DONE] - -# Hole Reacher [DONE] - -# ReacherNd -_versions = ["Reacher5d-v0", "Reacher7d-v0", "Reacher5dSparse-v0", "Reacher7dSparse-v0"] -for _v in _versions: - _name = _v.split("-") - _env_id = f'{_name[0]}DMP-{_name[1]}' - kwargs_dict_reacher_dmp = deepcopy(DEFAULT_BB_DICT_DMP) - kwargs_dict_reacher_dmp['wrappers'].append(mujoco.reacher.MPWrapper) - kwargs_dict_reacher_dmp['phase_generator_kwargs']['alpha_phase'] = 2 - kwargs_dict_reacher_dmp['name'] = _v - gym_register( - id=_env_id, - entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', - # max_episode_steps=1, - kwargs=kwargs_dict_reacher_dmp - ) - ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id) - - _env_id = f'{_name[0]}ProMP-{_name[1]}' - kwargs_dict_reacher_promp = deepcopy(DEFAULT_BB_DICT_ProMP) - kwargs_dict_reacher_promp['wrappers'].append(mujoco.reacher.MPWrapper) - kwargs_dict_reacher_promp['name'] = _v - gym_register( - id=_env_id, - entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', - kwargs=kwargs_dict_reacher_promp - ) - ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) -######################################################################################################################## # Beerpong ProMP _versions = ['BeerPong-v0'] for _v in _versions: @@ -321,72 +307,6 @@ for _v in _versions: # # ######################################################################################################################## - -# HopperJump -_versions = ['HopperJump-v0', 'HopperJumpSparse-v0', - # 'HopperJumpOnBox-v0', 'HopperThrow-v0', 'HopperThrowInBasket-v0' - ] -# TODO: Check if all environments work with the same MPWrapper -for _v in _versions: - _name = _v.split("-") - _env_id = f'{_name[0]}ProMP-{_name[1]}' - kwargs_dict_hopper_jump_promp = deepcopy(DEFAULT_BB_DICT_ProMP) - kwargs_dict_hopper_jump_promp['wrappers'].append(mujoco.hopper_jump.MPWrapper) - kwargs_dict_hopper_jump_promp['name'] = _v - gym_register( - id=_env_id, - entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', - kwargs=kwargs_dict_hopper_jump_promp - ) - ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) - -# ######################################################################################################################## - -# Box Pushing -_versions = ['BoxPushingDense-v0', 'BoxPushingTemporalSparse-v0', 'BoxPushingTemporalSpatialSparse-v0'] -for _v in _versions: - _name = _v.split("-") - _env_id = f'{_name[0]}ProMP-{_name[1]}' - kwargs_dict_box_pushing_promp = deepcopy(DEFAULT_BB_DICT_ProMP) - kwargs_dict_box_pushing_promp['wrappers'].append(mujoco.box_pushing.MPWrapper) - kwargs_dict_box_pushing_promp['name'] = _v - kwargs_dict_box_pushing_promp['controller_kwargs']['p_gains'] = 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.]) - kwargs_dict_box_pushing_promp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.]) - kwargs_dict_box_pushing_promp['basis_generator_kwargs']['basis_bandwidth_factor'] = 2 # 3.5, 4 to try - - gym_register( - id=_env_id, - entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', - kwargs=kwargs_dict_box_pushing_promp - ) - ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) - -for _v in _versions: - _name = _v.split("-") - _env_id = f'{_name[0]}ReplanProDMP-{_name[1]}' - kwargs_dict_box_pushing_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP) - kwargs_dict_box_pushing_prodmp['wrappers'].append(mujoco.box_pushing.MPWrapper) - kwargs_dict_box_pushing_prodmp['name'] = _v - kwargs_dict_box_pushing_prodmp['controller_kwargs']['p_gains'] = 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.]) - kwargs_dict_box_pushing_prodmp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.]) - kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = 0.3 - kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 0.3 - kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['auto_scale_basis'] = True - kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_offset'] = 1.0 - kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['disable_goal'] = True - kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['num_basis'] = 5 - kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3 - kwargs_dict_box_pushing_prodmp['phase_generator_kwargs']['alpha_phase'] = 3 - kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_planning_times'] = 4 - kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t: t % 25 == 0 - kwargs_dict_box_pushing_prodmp['black_box_kwargs']['condition_on_desired'] = True - gym_register( - id=_env_id, - entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', - kwargs=kwargs_dict_box_pushing_prodmp - ) - ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id) - # Table Tennis _versions = ['TableTennis2D-v0', 'TableTennis4D-v0', 'TableTennisWind-v0', 'TableTennisGoalSwitching-v0'] for _v in _versions: