From fea2ae7d11548654e505f8e28a8d3b0748adcc83 Mon Sep 17 00:00:00 2001 From: Fabian Date: Thu, 30 Jun 2022 17:33:05 +0200 Subject: [PATCH] current state --- alr_envs/alr/__init__.py | 190 +++++------------- .../hole_reacher/mp_wrapper.py | 27 +++ .../simple_reacher/mp_wrapper.py | 2 +- .../viapoint_reacher/mp_wrapper.py | 2 +- alr_envs/alr/mujoco/__init__.py | 2 +- alr_envs/alr/mujoco/ant_jump/mp_wrapper.py | 2 +- .../alr/mujoco/ant_jump/new_mp_wrapper.py | 4 +- .../ball_in_a_cup/ball_in_a_cup_mp_wrapper.py | 2 +- alr_envs/alr/mujoco/beerpong/mp_wrapper.py | 2 +- .../alr/mujoco/beerpong/new_mp_wrapper.py | 2 +- .../mujoco/half_cheetah_jump/mp_wrapper.py | 2 +- .../half_cheetah_jump/new_mp_wrapper.py | 2 +- alr_envs/alr/mujoco/hopper_jump/__init__.py | 2 +- .../alr/mujoco/hopper_jump/hopper_jump.py | 147 +++++++------- alr_envs/alr/mujoco/hopper_jump/mp_wrapper.py | 52 +---- .../alr/mujoco/hopper_jump/new_mp_wrapper.py | 45 ----- .../alr/mujoco/hopper_throw/hopper_throw.py | 2 +- .../alr/mujoco/hopper_throw/mp_wrapper.py | 2 +- .../alr/mujoco/hopper_throw/new_mp_wrapper.py | 2 +- alr_envs/alr/mujoco/reacher/alr_reacher.py | 152 -------------- alr_envs/alr/mujoco/reacher/balancing.py | 53 ----- alr_envs/alr/mujoco/reacher/mp_wrapper.py | 2 +- alr_envs/alr/mujoco/reacher/reacher.py | 105 ++++++++++ .../alr/mujoco/table_tennis/mp_wrapper.py | 2 +- .../alr/mujoco/walker_2d_jump/mp_wrapper.py | 2 +- .../mujoco/walker_2d_jump/new_mp_wrapper.py | 2 +- alr_envs/{mp => black_box}/__init__.py | 0 .../{mp => black_box}/black_box_wrapper.py | 14 +- .../controller}/__init__.py | 0 .../controller}/base_controller.py | 0 .../controller}/controller_factory.py | 8 +- .../controller}/meta_world_controller.py | 2 +- .../controller}/pd_controller.py | 2 +- .../controller}/pos_controller.py | 2 +- .../controller}/vel_controller.py | 2 +- alr_envs/black_box/factory/__init__.py | 0 .../factory}/basis_generator_factory.py | 0 .../factory}/phase_generator_factory.py | 2 +- .../factory/trajectory_generator_factory.py} | 4 +- .../raw_interface_wrapper.py | 0 .../dmc/manipulation/reach_site/mp_wrapper.py | 2 +- alr_envs/dmc/suite/ball_in_cup/mp_wrapper.py | 2 +- alr_envs/dmc/suite/cartpole/mp_wrapper.py | 2 +- alr_envs/dmc/suite/reacher/mp_wrapper.py | 2 +- alr_envs/meta/goal_change_mp_wrapper.py | 2 +- .../goal_endeffector_change_mp_wrapper.py | 2 +- .../meta/goal_object_change_mp_wrapper.py | 2 +- alr_envs/meta/object_change_mp_wrapper.py | 2 +- .../continuous_mountain_car/mp_wrapper.py | 2 +- .../open_ai/mujoco/reacher_v2/mp_wrapper.py | 2 +- alr_envs/open_ai/robotics/fetch/mp_wrapper.py | 2 +- alr_envs/utils/make_env_helpers.py | 15 +- 52 files changed, 325 insertions(+), 557 deletions(-) create mode 100644 alr_envs/alr/classic_control/hole_reacher/mp_wrapper.py delete mode 100644 alr_envs/alr/mujoco/hopper_jump/new_mp_wrapper.py delete mode 100644 alr_envs/alr/mujoco/reacher/alr_reacher.py delete mode 100644 alr_envs/alr/mujoco/reacher/balancing.py create mode 100644 alr_envs/alr/mujoco/reacher/reacher.py rename alr_envs/{mp => black_box}/__init__.py (100%) rename alr_envs/{mp => black_box}/black_box_wrapper.py (94%) rename alr_envs/{mp/controllers => black_box/controller}/__init__.py (100%) rename alr_envs/{mp/controllers => black_box/controller}/base_controller.py (100%) rename alr_envs/{mp/controllers => black_box/controller}/controller_factory.py (60%) rename alr_envs/{mp/controllers => black_box/controller}/meta_world_controller.py (92%) rename alr_envs/{mp/controllers => 
black_box/controller}/pd_controller.py (93%) rename alr_envs/{mp/controllers => black_box/controller}/pos_controller.py (77%) rename alr_envs/{mp/controllers => black_box/controller}/vel_controller.py (77%) create mode 100644 alr_envs/black_box/factory/__init__.py rename alr_envs/{mp => black_box/factory}/basis_generator_factory.py (100%) rename alr_envs/{mp => black_box/factory}/phase_generator_factory.py (93%) rename alr_envs/{mp/mp_factory.py => black_box/factory/trajectory_generator_factory.py} (91%) rename alr_envs/{mp => black_box}/raw_interface_wrapper.py (100%) diff --git a/alr_envs/alr/__init__.py b/alr_envs/alr/__init__.py index 435cfdb..cf009a4 100644 --- a/alr_envs/alr/__init__.py +++ b/alr_envs/alr/__init__.py @@ -1,33 +1,31 @@ -import numpy as np -from gym import register from copy import deepcopy +import numpy as np +from gym import register + +from alr_envs.alr.mujoco.table_tennis.tt_gym import MAX_EPISODE_STEPS from . import classic_control, mujoco from .classic_control.hole_reacher.hole_reacher import HoleReacherEnv from .classic_control.simple_reacher.simple_reacher import SimpleReacherEnv from .classic_control.viapoint_reacher.viapoint_reacher import ViaPointReacherEnv +from .mujoco.ant_jump.ant_jump import MAX_EPISODE_STEPS_ANTJUMP from .mujoco.ball_in_a_cup.ball_in_a_cup import ALRBallInACupEnv from .mujoco.ball_in_a_cup.biac_pd import ALRBallInACupPDEnv -from .mujoco.reacher.alr_reacher import ALRReacherEnv -from .mujoco.reacher.balancing import BalancingEnv - -from alr_envs.alr.mujoco.table_tennis.tt_gym import MAX_EPISODE_STEPS -from .mujoco.ant_jump.ant_jump import MAX_EPISODE_STEPS_ANTJUMP from .mujoco.half_cheetah_jump.half_cheetah_jump import MAX_EPISODE_STEPS_HALFCHEETAHJUMP from .mujoco.hopper_jump.hopper_jump import MAX_EPISODE_STEPS_HOPPERJUMP from .mujoco.hopper_jump.hopper_jump_on_box import MAX_EPISODE_STEPS_HOPPERJUMPONBOX from .mujoco.hopper_throw.hopper_throw import MAX_EPISODE_STEPS_HOPPERTHROW from .mujoco.hopper_throw.hopper_throw_in_basket import MAX_EPISODE_STEPS_HOPPERTHROWINBASKET +from .mujoco.reacher.reacher import ReacherEnv from .mujoco.walker_2d_jump.walker_2d_jump import MAX_EPISODE_STEPS_WALKERJUMP ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []} -DEFAULT_MP_ENV_DICT = { +DEFAULT_BB_DICT = { "name": 'EnvName', "wrappers": [], "traj_gen_kwargs": { - "weight_scale": 1, - 'movement_primitives_type': 'promp' + 'trajectory_generator_type': 'promp' }, "phase_generator_kwargs": { 'phase_generator_type': 'linear', @@ -100,80 +98,47 @@ register( # Mujoco ## Reacher +for _dims in [5, 7]: + register( + id=f'Reacher{_dims}d-v0', + entry_point='alr_envs.alr.mujoco:ReacherEnv', + max_episode_steps=200, + kwargs={ + "n_links": _dims, + } + ) + + register( + id=f'Reacher{_dims}dSparse-v0', + entry_point='alr_envs.alr.mujoco:ReacherEnv', + max_episode_steps=200, + kwargs={ + "sparse": True, + "n_links": _dims, + } + ) + +## Hopper Jump random joints and desired position register( - id='ALRReacher-v0', - entry_point='alr_envs.alr.mujoco:ALRReacherEnv', - max_episode_steps=200, + id='HopperJumpSparse-v0', + entry_point='alr_envs.alr.mujoco:ALRHopperXYJumpEnv', + max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP, kwargs={ - "steps_before_reward": 0, - "n_links": 5, - "balance": False, + # "max_episode_steps": MAX_EPISODE_STEPS_HOPPERJUMP, + "sparse": True, + # "healthy_reward": 1.0 } ) +## Hopper Jump random joints and desired position step based reward register( - id='ALRReacherSparse-v0', - entry_point='alr_envs.alr.mujoco:ALRReacherEnv', - 
max_episode_steps=200, + id='HopperJump-v0', + entry_point='alr_envs.alr.mujoco:ALRHopperXYJumpEnvStepBased', + max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP, kwargs={ - "steps_before_reward": 200, - "n_links": 5, - "balance": False, - } -) - -register( - id='ALRReacherSparseOptCtrl-v0', - entry_point='alr_envs.alr.mujoco:ALRReacherOptCtrlEnv', - max_episode_steps=200, - kwargs={ - "steps_before_reward": 200, - "n_links": 5, - "balance": False, - } -) - -register( - id='ALRReacherSparseBalanced-v0', - entry_point='alr_envs.alr.mujoco:ALRReacherEnv', - max_episode_steps=200, - kwargs={ - "steps_before_reward": 200, - "n_links": 5, - "balance": True, - } -) - -register( - id='ALRLongReacher-v0', - entry_point='alr_envs.alr.mujoco:ALRReacherEnv', - max_episode_steps=200, - kwargs={ - "steps_before_reward": 0, - "n_links": 7, - "balance": False, - } -) - -register( - id='ALRLongReacherSparse-v0', - entry_point='alr_envs.alr.mujoco:ALRReacherEnv', - max_episode_steps=200, - kwargs={ - "steps_before_reward": 200, - "n_links": 7, - "balance": False, - } -) - -register( - id='ALRLongReacherSparseBalanced-v0', - entry_point='alr_envs.alr.mujoco:ALRReacherEnv', - max_episode_steps=200, - kwargs={ - "steps_before_reward": 200, - "n_links": 7, - "balance": True, + # "max_episode_steps": MAX_EPISODE_STEPS_HOPPERJUMP, + "sparse": False, + # "healthy_reward": 1.0 } ) @@ -198,41 +163,7 @@ register( ) register( - id='ALRHopperJump-v0', - entry_point='alr_envs.alr.mujoco:ALRHopperJumpEnv', - max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP, - kwargs={ - "max_episode_steps": MAX_EPISODE_STEPS_HOPPERJUMP, - "context": True - } -) - -#### Hopper Jump random joints and des position -register( - id='ALRHopperJumpRndmJointsDesPos-v0', - entry_point='alr_envs.alr.mujoco:ALRHopperXYJumpEnv', - max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP, - kwargs={ - "max_episode_steps": MAX_EPISODE_STEPS_HOPPERJUMP, - "context": True, - "healthy_reward": 1.0 - } -) - -##### Hopper Jump random joints and des position step based reward -register( - id='ALRHopperJumpRndmJointsDesPosStepBased-v0', - entry_point='alr_envs.alr.mujoco:ALRHopperXYJumpEnvStepBased', - max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP, - kwargs={ - "max_episode_steps": MAX_EPISODE_STEPS_HOPPERJUMP, - "context": True, - "healthy_reward": 1.0 - } -) - -register( - id='ALRHopperJumpOnBox-v0', + id='HopperJumpOnBox-v0', entry_point='alr_envs.alr.mujoco:ALRHopperJumpOnBoxEnv', max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMPONBOX, kwargs={ @@ -271,17 +202,6 @@ register( } ) -## Balancing Reacher - -register( - id='Balancing-v0', - entry_point='alr_envs.alr.mujoco:BalancingEnv', - max_episode_steps=200, - kwargs={ - "n_links": 5, - } -) - ## Table Tennis register(id='TableTennis2DCtxt-v0', entry_point='alr_envs.alr.mujoco:TTEnvGym', @@ -361,7 +281,7 @@ for _v in _versions: ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id) _env_id = f'{_name[0]}ProMP-{_name[1]}' - kwargs_dict_simple_reacher_promp = deepcopy(DEFAULT_MP_ENV_DICT) + kwargs_dict_simple_reacher_promp = deepcopy(DEFAULT_BB_DICT) kwargs_dict_simple_reacher_promp['wrappers'].append(classic_control.simple_reacher.MPWrapper) kwargs_dict_simple_reacher_promp['controller_kwargs']['p_gains'] = 0.6 kwargs_dict_simple_reacher_promp['controller_kwargs']['d_gains'] = 0.075 @@ -394,7 +314,7 @@ register( ) ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append("ViaPointReacherDMP-v0") -kwargs_dict_via_point_reacher_promp = deepcopy(DEFAULT_MP_ENV_DICT) +kwargs_dict_via_point_reacher_promp = 
deepcopy(DEFAULT_BB_DICT) kwargs_dict_via_point_reacher_promp['wrappers'].append(classic_control.viapoint_reacher.MPWrapper) kwargs_dict_via_point_reacher_promp['controller_kwargs']['controller_type'] = 'velocity' kwargs_dict_via_point_reacher_promp['name'] = "ViaPointReacherProMP-v0" @@ -433,7 +353,7 @@ for _v in _versions: ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id) _env_id = f'{_name[0]}ProMP-{_name[1]}' - kwargs_dict_hole_reacher_promp = deepcopy(DEFAULT_MP_ENV_DICT) + kwargs_dict_hole_reacher_promp = deepcopy(DEFAULT_BB_DICT) kwargs_dict_hole_reacher_promp['wrappers'].append(classic_control.hole_reacher.MPWrapper) kwargs_dict_hole_reacher_promp['traj_gen_kwargs']['weight_scale'] = 2 kwargs_dict_hole_reacher_promp['controller_kwargs']['controller_type'] = 'velocity' @@ -475,7 +395,7 @@ for _v in _versions: ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id) _env_id = f'{_name[0]}ProMP-{_name[1]}' - kwargs_dict_alr_reacher_promp = deepcopy(DEFAULT_MP_ENV_DICT) + kwargs_dict_alr_reacher_promp = deepcopy(DEFAULT_BB_DICT) kwargs_dict_alr_reacher_promp['wrappers'].append(mujoco.reacher.MPWrapper) kwargs_dict_alr_reacher_promp['controller_kwargs']['p_gains'] = 1 kwargs_dict_alr_reacher_promp['controller_kwargs']['d_gains'] = 0.1 @@ -493,7 +413,7 @@ _versions = ['ALRBeerPong-v0'] for _v in _versions: _name = _v.split("-") _env_id = f'{_name[0]}ProMP-{_name[1]}' - kwargs_dict_bp_promp = deepcopy(DEFAULT_MP_ENV_DICT) + kwargs_dict_bp_promp = deepcopy(DEFAULT_BB_DICT) kwargs_dict_bp_promp['wrappers'].append(mujoco.beerpong.MPWrapper) kwargs_dict_bp_promp['phase_generator_kwargs']['learn_tau'] = True kwargs_dict_bp_promp['controller_kwargs']['p_gains'] = np.array([1.5, 5, 2.55, 3, 2., 2, 1.25]) @@ -513,7 +433,7 @@ _versions = ["ALRBeerPongStepBased-v0", "ALRBeerPongFixedRelease-v0"] for _v in _versions: _name = _v.split("-") _env_id = f'{_name[0]}ProMP-{_name[1]}' - kwargs_dict_bp_promp = deepcopy(DEFAULT_MP_ENV_DICT) + kwargs_dict_bp_promp = deepcopy(DEFAULT_BB_DICT) kwargs_dict_bp_promp['wrappers'].append(mujoco.beerpong.MPWrapper) kwargs_dict_bp_promp['phase_generator_kwargs']['tau'] = 0.62 kwargs_dict_bp_promp['controller_kwargs']['p_gains'] = np.array([1.5, 5, 2.55, 3, 2., 2, 1.25]) @@ -538,7 +458,7 @@ _versions = ['ALRAntJump-v0'] for _v in _versions: _name = _v.split("-") _env_id = f'{_name[0]}ProMP-{_name[1]}' - kwargs_dict_ant_jump_promp = deepcopy(DEFAULT_MP_ENV_DICT) + kwargs_dict_ant_jump_promp = deepcopy(DEFAULT_BB_DICT) kwargs_dict_ant_jump_promp['wrappers'].append(mujoco.ant_jump.MPWrapper) kwargs_dict_ant_jump_promp['name'] = f"alr_envs:{_v}" register( @@ -555,7 +475,7 @@ _versions = ['ALRHalfCheetahJump-v0'] for _v in _versions: _name = _v.split("-") _env_id = f'{_name[0]}ProMP-{_name[1]}' - kwargs_dict_halfcheetah_jump_promp = deepcopy(DEFAULT_MP_ENV_DICT) + kwargs_dict_halfcheetah_jump_promp = deepcopy(DEFAULT_BB_DICT) kwargs_dict_halfcheetah_jump_promp['wrappers'].append(mujoco.half_cheetah_jump.MPWrapper) kwargs_dict_halfcheetah_jump_promp['name'] = f"alr_envs:{_v}" register( @@ -575,7 +495,7 @@ _versions = ['ALRHopperJump-v0', 'ALRHopperJumpRndmJointsDesPos-v0', 'ALRHopperJ for _v in _versions: _name = _v.split("-") _env_id = f'{_name[0]}ProMP-{_name[1]}' - kwargs_dict_hopper_jump_promp = deepcopy(DEFAULT_MP_ENV_DICT) + kwargs_dict_hopper_jump_promp = deepcopy(DEFAULT_BB_DICT) kwargs_dict_hopper_jump_promp['wrappers'].append(mujoco.hopper_jump.MPWrapper) kwargs_dict_hopper_jump_promp['name'] = f"alr_envs:{_v}" register( @@ -593,7 +513,7 @@ _versions 
= ['ALRWalker2DJump-v0'] for _v in _versions: _name = _v.split("-") _env_id = f'{_name[0]}ProMP-{_name[1]}' - kwargs_dict_walker2d_jump_promp = deepcopy(DEFAULT_MP_ENV_DICT) + kwargs_dict_walker2d_jump_promp = deepcopy(DEFAULT_BB_DICT) kwargs_dict_walker2d_jump_promp['wrappers'].append(mujoco.walker_2d_jump.MPWrapper) kwargs_dict_walker2d_jump_promp['name'] = f"alr_envs:{_v}" register( @@ -695,7 +615,7 @@ for i in _vs: _env_id = f'ALRReacher{i}-v0' register( id=_env_id, - entry_point='alr_envs.alr.mujoco:ALRReacherEnv', + entry_point='alr_envs.alr.mujoco:ReacherEnv', max_episode_steps=200, kwargs={ "steps_before_reward": 0, @@ -708,7 +628,7 @@ for i in _vs: _env_id = f'ALRReacherSparse{i}-v0' register( id=_env_id, - entry_point='alr_envs.alr.mujoco:ALRReacherEnv', + entry_point='alr_envs.alr.mujoco:ReacherEnv', max_episode_steps=200, kwargs={ "steps_before_reward": 200, diff --git a/alr_envs/alr/classic_control/hole_reacher/mp_wrapper.py b/alr_envs/alr/classic_control/hole_reacher/mp_wrapper.py new file mode 100644 index 0000000..19bb0c5 --- /dev/null +++ b/alr_envs/alr/classic_control/hole_reacher/mp_wrapper.py @@ -0,0 +1,27 @@ +from typing import Tuple, Union + +import numpy as np + +from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper + + +class MPWrapper(RawInterfaceWrapper): + + def get_context_mask(self): + return np.hstack([ + [self.env.random_start] * self.env.n_links, # cos + [self.env.random_start] * self.env.n_links, # sin + [self.env.random_start] * self.env.n_links, # velocity + [self.env.initial_width is None], # hole width + # [self.env.hole_depth is None], # hole depth + [True] * 2, # x-y coordinates of target distance + [False] # env steps + ]) + + @property + def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: + return self.env.current_pos + + @property + def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: + return self.env.current_vel diff --git a/alr_envs/alr/classic_control/simple_reacher/mp_wrapper.py b/alr_envs/alr/classic_control/simple_reacher/mp_wrapper.py index 30b0985..c5ef66f 100644 --- a/alr_envs/alr/classic_control/simple_reacher/mp_wrapper.py +++ b/alr_envs/alr/classic_control/simple_reacher/mp_wrapper.py @@ -2,7 +2,7 @@ from typing import Tuple, Union import numpy as np -from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper +from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper class MPWrapper(RawInterfaceWrapper): diff --git a/alr_envs/alr/classic_control/viapoint_reacher/mp_wrapper.py b/alr_envs/alr/classic_control/viapoint_reacher/mp_wrapper.py index 9f40292..2b210de 100644 --- a/alr_envs/alr/classic_control/viapoint_reacher/mp_wrapper.py +++ b/alr_envs/alr/classic_control/viapoint_reacher/mp_wrapper.py @@ -2,7 +2,7 @@ from typing import Tuple, Union import numpy as np -from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper +from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper class MPWrapper(RawInterfaceWrapper): diff --git a/alr_envs/alr/mujoco/__init__.py b/alr_envs/alr/mujoco/__init__.py index f2f4536..df52cfc 100644 --- a/alr_envs/alr/mujoco/__init__.py +++ b/alr_envs/alr/mujoco/__init__.py @@ -6,7 +6,7 @@ from .half_cheetah_jump.half_cheetah_jump import ALRHalfCheetahJumpEnv from .hopper_jump.hopper_jump_on_box import ALRHopperJumpOnBoxEnv from .hopper_throw.hopper_throw import ALRHopperThrowEnv from .hopper_throw.hopper_throw_in_basket import ALRHopperThrowInBasketEnv -from .reacher.alr_reacher import ALRReacherEnv +from .reacher.reacher import 
ReacherEnv from .reacher.balancing import BalancingEnv from .table_tennis.tt_gym import TTEnvGym from .walker_2d_jump.walker_2d_jump import ALRWalker2dJumpEnv diff --git a/alr_envs/alr/mujoco/ant_jump/mp_wrapper.py b/alr_envs/alr/mujoco/ant_jump/mp_wrapper.py index 4d5c0d6..f6e99a3 100644 --- a/alr_envs/alr/mujoco/ant_jump/mp_wrapper.py +++ b/alr_envs/alr/mujoco/ant_jump/mp_wrapper.py @@ -2,7 +2,7 @@ from typing import Tuple, Union import numpy as np -from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper +from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper class MPWrapper(RawInterfaceWrapper): diff --git a/alr_envs/alr/mujoco/ant_jump/new_mp_wrapper.py b/alr_envs/alr/mujoco/ant_jump/new_mp_wrapper.py index 0886065..f6f026b 100644 --- a/alr_envs/alr/mujoco/ant_jump/new_mp_wrapper.py +++ b/alr_envs/alr/mujoco/ant_jump/new_mp_wrapper.py @@ -1,8 +1,8 @@ -from alr_envs.mp.black_box_wrapper import BlackBoxWrapper +from alr_envs.black_box.black_box_wrapper import BlackBoxWrapper from typing import Union, Tuple import numpy as np -from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper +from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper class MPWrapper(RawInterfaceWrapper): diff --git a/alr_envs/alr/mujoco/ball_in_a_cup/ball_in_a_cup_mp_wrapper.py b/alr_envs/alr/mujoco/ball_in_a_cup/ball_in_a_cup_mp_wrapper.py index 609858b..81a08f5 100644 --- a/alr_envs/alr/mujoco/ball_in_a_cup/ball_in_a_cup_mp_wrapper.py +++ b/alr_envs/alr/mujoco/ball_in_a_cup/ball_in_a_cup_mp_wrapper.py @@ -2,7 +2,7 @@ from typing import Tuple, Union import numpy as np -from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper +from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper class BallInACupMPWrapper(RawInterfaceWrapper): diff --git a/alr_envs/alr/mujoco/beerpong/mp_wrapper.py b/alr_envs/alr/mujoco/beerpong/mp_wrapper.py index 40c371b..20e7532 100644 --- a/alr_envs/alr/mujoco/beerpong/mp_wrapper.py +++ b/alr_envs/alr/mujoco/beerpong/mp_wrapper.py @@ -2,7 +2,7 @@ from typing import Tuple, Union import numpy as np -from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper +from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper class MPWrapper(RawInterfaceWrapper): diff --git a/alr_envs/alr/mujoco/beerpong/new_mp_wrapper.py b/alr_envs/alr/mujoco/beerpong/new_mp_wrapper.py index 2969b82..bd22442 100644 --- a/alr_envs/alr/mujoco/beerpong/new_mp_wrapper.py +++ b/alr_envs/alr/mujoco/beerpong/new_mp_wrapper.py @@ -2,7 +2,7 @@ from typing import Union, Tuple import numpy as np -from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper +from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper class MPWrapper(RawInterfaceWrapper): diff --git a/alr_envs/alr/mujoco/half_cheetah_jump/mp_wrapper.py b/alr_envs/alr/mujoco/half_cheetah_jump/mp_wrapper.py index f9a298a..930da6d 100644 --- a/alr_envs/alr/mujoco/half_cheetah_jump/mp_wrapper.py +++ b/alr_envs/alr/mujoco/half_cheetah_jump/mp_wrapper.py @@ -2,7 +2,7 @@ from typing import Tuple, Union import numpy as np -from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper +from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper class MPWrapper(RawInterfaceWrapper): diff --git a/alr_envs/alr/mujoco/half_cheetah_jump/new_mp_wrapper.py b/alr_envs/alr/mujoco/half_cheetah_jump/new_mp_wrapper.py index d14c9a9..9a65952 100644 --- a/alr_envs/alr/mujoco/half_cheetah_jump/new_mp_wrapper.py +++ 
b/alr_envs/alr/mujoco/half_cheetah_jump/new_mp_wrapper.py @@ -2,7 +2,7 @@ from typing import Tuple, Union import numpy as np -from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper +from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper class MPWrapper(RawInterfaceWrapper): diff --git a/alr_envs/alr/mujoco/hopper_jump/__init__.py b/alr_envs/alr/mujoco/hopper_jump/__init__.py index 8a04a02..c5e6d2f 100644 --- a/alr_envs/alr/mujoco/hopper_jump/__init__.py +++ b/alr_envs/alr/mujoco/hopper_jump/__init__.py @@ -1 +1 @@ -from .new_mp_wrapper import MPWrapper +from .mp_wrapper import MPWrapper diff --git a/alr_envs/alr/mujoco/hopper_jump/hopper_jump.py b/alr_envs/alr/mujoco/hopper_jump/hopper_jump.py index 025bb8d..78b06d3 100644 --- a/alr_envs/alr/mujoco/hopper_jump/hopper_jump.py +++ b/alr_envs/alr/mujoco/hopper_jump/hopper_jump.py @@ -1,3 +1,5 @@ +from typing import Optional + from gym.envs.mujoco.hopper_v3 import HopperEnv import numpy as np import os @@ -8,10 +10,10 @@ MAX_EPISODE_STEPS_HOPPERJUMP = 250 class ALRHopperJumpEnv(HopperEnv): """ Initialization changes to normal Hopper: - - healthy_reward: 1.0 -> 0.1 -> 0 - - healthy_angle_range: (-0.2, 0.2) -> (-float('inf'), float('inf')) + - terminate_when_unhealthy: True -> False - healthy_z_range: (0.7, float('inf')) -> (0.5, float('inf')) - - exclude current positions from observatiosn is set to False + - healthy_angle_range: (-0.2, 0.2) -> (-float('inf'), float('inf')) + - exclude_current_positions_from_observation: True -> False """ def __init__( @@ -19,76 +21,93 @@ class ALRHopperJumpEnv(HopperEnv): xml_file='hopper_jump.xml', forward_reward_weight=1.0, ctrl_cost_weight=1e-3, - healthy_reward=0.0, + healthy_reward=1.0, penalty=0.0, - context=True, terminate_when_unhealthy=False, healthy_state_range=(-100.0, 100.0), healthy_z_range=(0.5, float('inf')), healthy_angle_range=(-float('inf'), float('inf')), reset_noise_scale=5e-3, exclude_current_positions_from_observation=False, - max_episode_steps=250 - ): + ): - self.current_step = 0 + self._steps = 0 self.max_height = 0 - self.max_episode_steps = max_episode_steps - self.penalty = penalty + # self.penalty = penalty self.goal = 0 - self.context = context - self.exclude_current_positions_from_observation = exclude_current_positions_from_observation + self._floor_geom_id = None self._foot_geom_id = None + self.contact_with_floor = False self.init_floor_contact = False self.has_left_floor = False self.contact_dist = None + xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file) super().__init__(xml_file, forward_reward_weight, ctrl_cost_weight, healthy_reward, terminate_when_unhealthy, healthy_state_range, healthy_z_range, healthy_angle_range, reset_noise_scale, exclude_current_positions_from_observation) def step(self, action): + self._steps += 1 + + self._floor_geom_id = self.model.geom_name2id('floor') + self._foot_geom_id = self.model.geom_name2id('foot_geom') - self.current_step += 1 self.do_simulation(action, self.frame_skip) + height_after = self.get_body_com("torso")[2] - # site_pos_after = self.sim.data.site_xpos[self.model.site_name2id('foot_site')].copy() - site_pos_after = self.get_body_com('foot_site') + site_pos_after = self.data.get_site_xpos('foot_site') self.max_height = max(height_after, self.max_height) + has_floor_contact = self._is_floor_foot_contact() if not self.contact_with_floor else False + + if not self.init_floor_contact: + self.init_floor_contact = has_floor_contact + if self.init_floor_contact and not 
self.has_left_floor: + self.has_left_floor = not has_floor_contact + if not self.contact_with_floor and self.has_left_floor: + self.contact_with_floor = has_floor_contact + ctrl_cost = self.control_cost(action) costs = ctrl_cost done = False + goal_dist = np.linalg.norm(site_pos_after - np.array([self.goal, 0, 0])) + + if self.contact_dist is None and self.contact_with_floor: + self.contact_dist = goal_dist + rewards = 0 - if self.current_step >= self.max_episode_steps: - hight_goal_distance = -10 * np.linalg.norm(self.max_height - self.goal) if self.context else self.max_height - healthy_reward = 0 if self.context else self.healthy_reward * 2 # self.current_step - height_reward = self._forward_reward_weight * hight_goal_distance # maybe move reward calculation into if structure and define two different _forward_reward_weight variables for context and episodic seperatley - rewards = height_reward + healthy_reward + if self._steps >= MAX_EPISODE_STEPS_HOPPERJUMP: + # healthy_reward = 0 if self.context else self.healthy_reward * self._steps + healthy_reward = self.healthy_reward * 2 # * self._steps + contact_dist = self.contact_dist if self.contact_dist is not None else 5 + dist_reward = self._forward_reward_weight * (-3 * goal_dist + 10 * self.max_height - 2 * contact_dist) + rewards = dist_reward + healthy_reward observation = self._get_obs() reward = rewards - costs - - info = { - 'height': height_after, - 'x_pos': site_pos_after, - 'max_height': self.max_height, - 'height_rew': self.max_height, - 'healthy_reward': self.healthy_reward * 2, - 'healthy': self.is_healthy - } - + info = dict( + height=height_after, + x_pos=site_pos_after, + max_height=self.max_height, + goal=self.goal, + goal_dist=goal_dist, + height_rew=self.max_height, + healthy_reward=self.healthy_reward * 2, + healthy=self.is_healthy, + contact_dist=self.contact_dist if self.contact_dist is not None else 0 + ) return observation, reward, done, info def _get_obs(self): return np.append(super()._get_obs(), self.goal) - def reset(self): + def reset(self, *, seed: Optional[int] = None, return_info: bool = False, options: Optional[dict] = None, ): self.goal = self.np_random.uniform(1.4, 2.16, 1)[0] # 1.3 2.3 self.max_height = 0 - self.current_step = 0 + self._steps = 0 return super().reset() # overwrite reset_model to make it deterministic @@ -106,11 +125,13 @@ class ALRHopperJumpEnv(HopperEnv): self.contact_dist = None return observation - def _contact_checker(self, id_1, id_2): - for coni in range(0, self.sim.data.ncon): - con = self.sim.data.contact[coni] - collision = con.geom1 == id_1 and con.geom2 == id_2 - collision_trans = con.geom1 == id_2 and con.geom2 == id_1 + def _is_floor_foot_contact(self): + floor_geom_id = self.model.geom_name2id('floor') + foot_geom_id = self.model.geom_name2id('foot_geom') + for i in range(self.data.ncon): + contact = self.data.contact[i] + collision = contact.geom1 == floor_geom_id and contact.geom2 == foot_geom_id + collision_trans = contact.geom1 == foot_geom_id and contact.geom2 == floor_geom_id if collision or collision_trans: return True return False @@ -122,7 +143,7 @@ class ALRHopperXYJumpEnv(ALRHopperJumpEnv): self._floor_geom_id = self.model.geom_name2id('floor') self._foot_geom_id = self.model.geom_name2id('foot_geom') - self.current_step += 1 + self._steps += 1 self.do_simulation(action, self.frame_skip) height_after = self.get_body_com("torso")[2] site_pos_after = self.sim.data.site_xpos[self.model.site_name2id('foot_site')].copy() @@ -133,8 +154,8 @@ class 
ALRHopperXYJumpEnv(ALRHopperJumpEnv): # self.has_left_floor = not floor_contact if self.init_floor_contact and not self.has_left_floor else self.has_left_floor # self.contact_with_floor = floor_contact if not self.contact_with_floor and self.has_left_floor else self.contact_with_floor - floor_contact = self._contact_checker(self._floor_geom_id, - self._foot_geom_id) if not self.contact_with_floor else False + floor_contact = self._is_floor_foot_contact(self._floor_geom_id, + self._foot_geom_id) if not self.contact_with_floor else False if not self.init_floor_contact: self.init_floor_contact = floor_contact if self.init_floor_contact and not self.has_left_floor: @@ -151,9 +172,9 @@ class ALRHopperXYJumpEnv(ALRHopperJumpEnv): done = False goal_dist = np.linalg.norm(site_pos_after - np.array([self.goal, 0, 0])) rewards = 0 - if self.current_step >= self.max_episode_steps: - # healthy_reward = 0 if self.context else self.healthy_reward * self.current_step - healthy_reward = self.healthy_reward * 2 # * self.current_step + if self._steps >= self.max_episode_steps: + # healthy_reward = 0 if self.context else self.healthy_reward * self._steps + healthy_reward = self.healthy_reward * 2 # * self._steps contact_dist = self.contact_dist if self.contact_dist is not None else 5 dist_reward = self._forward_reward_weight * (-3 * goal_dist + 10 * self.max_height - 2 * contact_dist) rewards = dist_reward + healthy_reward @@ -170,7 +191,7 @@ class ALRHopperXYJumpEnv(ALRHopperJumpEnv): 'healthy_reward': self.healthy_reward * 2, 'healthy': self.is_healthy, 'contact_dist': self.contact_dist if self.contact_dist is not None else 0 - } + } return observation, reward, done, info def reset_model(self): @@ -242,7 +263,7 @@ class ALRHopperXYJumpEnvStepBased(ALRHopperXYJumpEnv): height_scale=10, dist_scale=3, healthy_scale=2 - ): + ): self.height_scale = height_scale self.dist_scale = dist_scale self.healthy_scale = healthy_scale @@ -254,7 +275,7 @@ class ALRHopperXYJumpEnvStepBased(ALRHopperXYJumpEnv): self._floor_geom_id = self.model.geom_name2id('floor') self._foot_geom_id = self.model.geom_name2id('foot_geom') - self.current_step += 1 + self._steps += 1 self.do_simulation(action, self.frame_skip) height_after = self.get_body_com("torso")[2] site_pos_after = self.sim.data.site_xpos[self.model.site_name2id('foot_site')].copy() @@ -273,8 +294,8 @@ class ALRHopperXYJumpEnvStepBased(ALRHopperXYJumpEnv): ########################################################### # This is only for logging the distance to goal when first having the contact ########################################################## - floor_contact = self._contact_checker(self._floor_geom_id, - self._foot_geom_id) if not self.contact_with_floor else False + floor_contact = self._is_floor_foot_contact(self._floor_geom_id, + self._foot_geom_id) if not self.contact_with_floor else False if not self.init_floor_contact: self.init_floor_contact = floor_contact if self.init_floor_contact and not self.has_left_floor: @@ -295,33 +316,5 @@ class ALRHopperXYJumpEnvStepBased(ALRHopperXYJumpEnv): 'healthy_reward': self.healthy_reward * self.healthy_reward, 'healthy': self.is_healthy, 'contact_dist': self.contact_dist if self.contact_dist is not None else 0 - } + } return observation, reward, done, info - - -if __name__ == '__main__': - render_mode = "human" # "human" or "partial" or "final" - # env = ALRHopperJumpEnv() - # env = ALRHopperXYJumpEnv() - np.random.seed(0) - env = ALRHopperXYJumpEnvStepBased() - env.seed(0) - # env = ALRHopperJumpRndmPosEnv() - obs = 
env.reset() - - for k in range(1000): - obs = env.reset() - print('observation :', obs[:]) - for i in range(200): - # objective.load_result("/tmp/cma") - # test with random actions - ac = env.action_space.sample() - obs, rew, d, info = env.step(ac) - # if i % 10 == 0: - # env.render(mode=render_mode) - env.render(mode=render_mode) - if d: - print('After ', i, ' steps, done: ', d) - env.reset() - - env.close() diff --git a/alr_envs/alr/mujoco/hopper_jump/mp_wrapper.py b/alr_envs/alr/mujoco/hopper_jump/mp_wrapper.py index e3279aa..c7b16db 100644 --- a/alr_envs/alr/mujoco/hopper_jump/mp_wrapper.py +++ b/alr_envs/alr/mujoco/hopper_jump/mp_wrapper.py @@ -1,57 +1,25 @@ -from typing import Tuple, Union +from typing import Union, Tuple import numpy as np -from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper +from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper class MPWrapper(RawInterfaceWrapper): - @property - def context_mask(self) -> np.ndarray: + + # Random x goal + random init pos + def context_mask(self): return np.hstack([ - [False] * (5 + int(not self.exclude_current_positions_from_observation)), # position + [False] * (2 + int(not self.exclude_current_positions_from_observation)), # position + [True] * 3, # set to true if randomize initial pos [False] * 6, # velocity [True] ]) @property - def current_pos(self) -> Union[float, int, np.ndarray]: - return self.env.sim.data.qpos[3:6].copy() + def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: + return self.sim.data.qpos[3:6].copy() @property def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: - return self.env.sim.data.qvel[3:6].copy() - - @property - def goal_pos(self) -> Union[float, int, np.ndarray, Tuple]: - raise ValueError("Goal position is not available and has to be learnt based on the environment.") - - @property - def dt(self) -> Union[float, int]: - return self.env.dt - - -class HighCtxtMPWrapper(MPWrapper): - @property - def active_obs(self): - return np.hstack([ - [True] * (5 + int(not self.exclude_current_positions_from_observation)), # position - [False] * 6, # velocity - [False] - ]) - - @property - def current_pos(self) -> Union[float, int, np.ndarray]: - return self.env.sim.data.qpos[3:6].copy() - - @property - def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: - return self.env.sim.data.qvel[3:6].copy() - - @property - def goal_pos(self) -> Union[float, int, np.ndarray, Tuple]: - raise ValueError("Goal position is not available and has to be learnt based on the environment.") - - @property - def dt(self) -> Union[float, int]: - return self.env.dt + return self.sim.data.qvel[3:6].copy() diff --git a/alr_envs/alr/mujoco/hopper_jump/new_mp_wrapper.py b/alr_envs/alr/mujoco/hopper_jump/new_mp_wrapper.py deleted file mode 100644 index b919b22..0000000 --- a/alr_envs/alr/mujoco/hopper_jump/new_mp_wrapper.py +++ /dev/null @@ -1,45 +0,0 @@ -from alr_envs.mp.black_box_wrapper import BlackBoxWrapper -from typing import Union, Tuple -import numpy as np - - -class MPWrapper(BlackBoxWrapper): - @property - def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: - return self.env.sim.data.qpos[3:6].copy() - - @property - def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: - return self.env.sim.data.qvel[3:6].copy() - - # # random goal - # def set_active_obs(self): - # return np.hstack([ - # [False] * (5 + int(not self.env.exclude_current_positions_from_observation)), # position - # [False] * 6, # velocity - # [True] - # ]) - - # Random x goal + random 
init pos - def get_context_mask(self): - return np.hstack([ - [False] * (2 + int(not self.env.exclude_current_positions_from_observation)), # position - [True] * 3, # set to true if randomize initial pos - [False] * 6, # velocity - [True] - ]) - - -class NewHighCtxtMPWrapper(MPWrapper): - def get_context_mask(self): - return np.hstack([ - [False] * (2 + int(not self.env.exclude_current_positions_from_observation)), # position - [True] * 3, # set to true if randomize initial pos - [False] * 6, # velocity - [True], # goal - [False] * 3 # goal diff - ]) - - def set_context(self, context): - return self.get_observation_from_step(self.env.env.set_context(context)) - diff --git a/alr_envs/alr/mujoco/hopper_throw/hopper_throw.py b/alr_envs/alr/mujoco/hopper_throw/hopper_throw.py index 03553f2..7ae33d1 100644 --- a/alr_envs/alr/mujoco/hopper_throw/hopper_throw.py +++ b/alr_envs/alr/mujoco/hopper_throw/hopper_throw.py @@ -67,7 +67,7 @@ class ALRHopperThrowEnv(HopperEnv): info = { 'ball_pos': ball_pos_after, 'ball_pos_y': ball_pos_after_y, - 'current_step' : self.current_step, + '_steps' : self.current_step, 'goal' : self.goal, } diff --git a/alr_envs/alr/mujoco/hopper_throw/mp_wrapper.py b/alr_envs/alr/mujoco/hopper_throw/mp_wrapper.py index f5bf08d..7778e8c 100644 --- a/alr_envs/alr/mujoco/hopper_throw/mp_wrapper.py +++ b/alr_envs/alr/mujoco/hopper_throw/mp_wrapper.py @@ -2,7 +2,7 @@ from typing import Tuple, Union import numpy as np -from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper +from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper class MPWrapper(RawInterfaceWrapper): diff --git a/alr_envs/alr/mujoco/hopper_throw/new_mp_wrapper.py b/alr_envs/alr/mujoco/hopper_throw/new_mp_wrapper.py index a8cd696..01d87a4 100644 --- a/alr_envs/alr/mujoco/hopper_throw/new_mp_wrapper.py +++ b/alr_envs/alr/mujoco/hopper_throw/new_mp_wrapper.py @@ -2,7 +2,7 @@ from typing import Tuple, Union import numpy as np -from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper +from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper class MPWrapper(RawInterfaceWrapper): diff --git a/alr_envs/alr/mujoco/reacher/alr_reacher.py b/alr_envs/alr/mujoco/reacher/alr_reacher.py deleted file mode 100644 index 0699c44..0000000 --- a/alr_envs/alr/mujoco/reacher/alr_reacher.py +++ /dev/null @@ -1,152 +0,0 @@ -import os - -import numpy as np -from gym import utils -from gym.envs.mujoco import MujocoEnv - -import alr_envs.utils.utils as alr_utils - - -class ALRReacherEnv(MujocoEnv, utils.EzPickle): - def __init__(self, steps_before_reward: int = 200, n_links: int = 5, ctrl_cost_weight: int = 1, - balance: bool = False): - utils.EzPickle.__init__(**locals()) - - self._steps = 0 - self.steps_before_reward = steps_before_reward - self.n_links = n_links - - self.balance = balance - self.balance_weight = 1.0 - self.ctrl_cost_weight = ctrl_cost_weight - - self.reward_weight = 1 - if steps_before_reward == 200: - self.reward_weight = 200 - elif steps_before_reward == 50: - self.reward_weight = 50 - - if n_links == 5: - file_name = 'reacher_5links.xml' - elif n_links == 7: - file_name = 'reacher_7links.xml' - else: - raise ValueError(f"Invalid number of links {n_links}, only 5 or 7 allowed.") - - MujocoEnv.__init__(self, os.path.join(os.path.dirname(__file__), "assets", file_name), 2) - - def step(self, a): - self._steps += 1 - - reward_dist = 0.0 - angular_vel = 0.0 - reward_balance = 0.0 - is_delayed = self.steps_before_reward > 0 - reward_ctrl = - np.square(a).sum() * 
self.ctrl_cost_weight - if self._steps >= self.steps_before_reward: - vec = self.get_body_com("fingertip") - self.get_body_com("target") - reward_dist -= self.reward_weight * np.linalg.norm(vec) - if is_delayed: - # avoid giving this penalty for normal step based case - # angular_vel -= 10 * np.linalg.norm(self.sim.data.qvel.flat[:self.n_links]) - angular_vel -= 10 * np.square(self.sim.data.qvel.flat[:self.n_links]).sum() - # if is_delayed: - # # Higher control penalty for sparse reward per timestep - # reward_ctrl *= 10 - - if self.balance: - reward_balance -= self.balance_weight * np.abs( - alr_utils.angle_normalize(np.sum(self.sim.data.qpos.flat[:self.n_links]), type="rad")) - - reward = reward_dist + reward_ctrl + angular_vel + reward_balance - self.do_simulation(a, self.frame_skip) - ob = self._get_obs() - done = False - return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl, - velocity=angular_vel, reward_balance=reward_balance, - end_effector=self.get_body_com("fingertip").copy(), - goal=self.goal if hasattr(self, "goal") else None) - - def viewer_setup(self): - self.viewer.cam.trackbodyid = 0 - - # def reset_model(self): - # qpos = self.init_qpos - # if not hasattr(self, "goal"): - # self.goal = np.array([-0.25, 0.25]) - # # self.goal = self.init_qpos.copy()[:2] + 0.05 - # qpos[-2:] = self.goal - # qvel = self.init_qvel - # qvel[-2:] = 0 - # self.set_state(qpos, qvel) - # self._steps = 0 - # - # return self._get_obs() - - def reset_model(self): - qpos = self.init_qpos.copy() - while True: - # full space - # self.goal = self.np_random.uniform(low=-self.n_links / 10, high=self.n_links / 10, size=2) - # I Quadrant - # self.goal = self.np_random.uniform(low=0, high=self.n_links / 10, size=2) - # II Quadrant - # self.goal = np.random.uniform(low=[-self.n_links / 10, 0], high=[0, self.n_links / 10], size=2) - # II + III Quadrant - # self.goal = np.random.uniform(low=-self.n_links / 10, high=[0, self.n_links / 10], size=2) - # I + II Quadrant - self.goal = np.random.uniform(low=[-self.n_links / 10, 0], high=self.n_links, size=2) - if np.linalg.norm(self.goal) < self.n_links / 10: - break - qpos[-2:] = self.goal - qvel = self.init_qvel.copy() - qvel[-2:] = 0 - self.set_state(qpos, qvel) - self._steps = 0 - - return self._get_obs() - - # def reset_model(self): - # qpos = self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq) + self.init_qpos - # while True: - # self.goal = self.np_random.uniform(low=-self.n_links / 10, high=self.n_links / 10, size=2) - # if np.linalg.norm(self.goal) < self.n_links / 10: - # break - # qpos[-2:] = self.goal - # qvel = self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv) - # qvel[-2:] = 0 - # self.set_state(qpos, qvel) - # self._steps = 0 - # - # return self._get_obs() - - def _get_obs(self): - theta = self.sim.data.qpos.flat[:self.n_links] - target = self.get_body_com("target") - return np.concatenate([ - np.cos(theta), - np.sin(theta), - target[:2], # x-y of goal position - self.sim.data.qvel.flat[:self.n_links], # angular velocity - self.get_body_com("fingertip") - target, # goal distance - [self._steps], - ]) - - -if __name__ == '__main__': - nl = 5 - render_mode = "human" # "human" or "partial" or "final" - env = ALRReacherEnv(n_links=nl) - obs = env.reset() - - for i in range(2000): - # objective.load_result("/tmp/cma") - # test with random actions - ac = env.action_space.sample() - obs, rew, d, info = env.step(ac) - if i % 10 == 0: - env.render(mode=render_mode) - if d: - env.reset() - - 
env.close() diff --git a/alr_envs/alr/mujoco/reacher/balancing.py b/alr_envs/alr/mujoco/reacher/balancing.py deleted file mode 100644 index 3e34298..0000000 --- a/alr_envs/alr/mujoco/reacher/balancing.py +++ /dev/null @@ -1,53 +0,0 @@ -import os - -import numpy as np -from gym import utils -from gym.envs.mujoco import mujoco_env - -import alr_envs.utils.utils as alr_utils - - -class BalancingEnv(mujoco_env.MujocoEnv, utils.EzPickle): - def __init__(self, n_links=5): - utils.EzPickle.__init__(**locals()) - - self.n_links = n_links - - if n_links == 5: - file_name = 'reacher_5links.xml' - elif n_links == 7: - file_name = 'reacher_7links.xml' - else: - raise ValueError(f"Invalid number of links {n_links}, only 5 or 7 allowed.") - - mujoco_env.MujocoEnv.__init__(self, os.path.join(os.path.dirname(__file__), "assets", file_name), 2) - - def step(self, a): - angle = alr_utils.angle_normalize(np.sum(self.sim.data.qpos.flat[:self.n_links]), type="rad") - reward = - np.abs(angle) - - self.do_simulation(a, self.frame_skip) - ob = self._get_obs() - done = False - return ob, reward, done, dict(angle=angle, end_effector=self.get_body_com("fingertip").copy()) - - def viewer_setup(self): - self.viewer.cam.trackbodyid = 1 - - def reset_model(self): - # This also generates a goal, we however do not need/use it - qpos = self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq) + self.init_qpos - qpos[-2:] = 0 - qvel = self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv) - qvel[-2:] = 0 - self.set_state(qpos, qvel) - - return self._get_obs() - - def _get_obs(self): - theta = self.sim.data.qpos.flat[:self.n_links] - return np.concatenate([ - np.cos(theta), - np.sin(theta), - self.sim.data.qvel.flat[:self.n_links], # this is angular velocity - ]) diff --git a/alr_envs/alr/mujoco/reacher/mp_wrapper.py b/alr_envs/alr/mujoco/reacher/mp_wrapper.py index 966be23..de33ae0 100644 --- a/alr_envs/alr/mujoco/reacher/mp_wrapper.py +++ b/alr_envs/alr/mujoco/reacher/mp_wrapper.py @@ -2,7 +2,7 @@ from typing import Union, Tuple import numpy as np -from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper +from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper class MPWrapper(RawInterfaceWrapper): diff --git a/alr_envs/alr/mujoco/reacher/reacher.py b/alr_envs/alr/mujoco/reacher/reacher.py new file mode 100644 index 0000000..a5d34ee --- /dev/null +++ b/alr_envs/alr/mujoco/reacher/reacher.py @@ -0,0 +1,105 @@ +import os + +import numpy as np +from gym import utils +from gym.envs.mujoco import MujocoEnv + + +class ReacherEnv(MujocoEnv, utils.EzPickle): + """ + More general version of the gym mujoco Reacher environment + """ + + def __init__(self, sparse: bool = False, n_links: int = 5, ctrl_cost_weight: int = 1): + utils.EzPickle.__init__(**locals()) + + self._steps = 0 + self.n_links = n_links + + self.ctrl_cost_weight = ctrl_cost_weight + + self.sparse = sparse + self.reward_weight = 1 if not sparse else 200 + + file_name = f'reacher_{n_links}links.xml' + + MujocoEnv.__init__(self, os.path.join(os.path.dirname(__file__), "assets", file_name), 2) + + def step(self, action): + self._steps += 1 + + is_reward = not self.sparse or (self.sparse and self._steps == 200) + + reward_dist = 0.0 + angular_vel = 0.0 + if is_reward: + reward_dist = self.distance_reward() + angular_vel = self.velocity_reward() + + reward_ctrl = -self.ctrl_cost_weight * np.square(action).sum() + + reward = reward_dist + reward_ctrl + angular_vel + self.do_simulation(action, self.frame_skip) + ob = 
self._get_obs() + done = False + + infos = dict( + reward_dist=reward_dist, + reward_ctrl=reward_ctrl, + velocity=angular_vel, + end_effector=self.get_body_com("fingertip").copy(), + goal=self.goal if hasattr(self, "goal") else None + ) + + return ob, reward, done, infos + + def distance_reward(self): + vec = self.get_body_com("fingertip") - self.get_body_com("target") + return -self.reward_weight * np.linalg.norm(vec) + + def velocity_reward(self): + return -10 * np.square(self.sim.data.qvel.flat[:self.n_links]).sum() if self.sparse else 0.0 + + def viewer_setup(self): + self.viewer.cam.trackbodyid = 0 + + def reset_model(self): + qpos = ( + # self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq) + + self.init_qpos.copy() + ) + while True: + # full space + self.goal = self.np_random.uniform(low=-self.n_links / 10, high=self.n_links / 10, size=2) + # I Quadrant + # self.goal = self.np_random.uniform(low=0, high=self.n_links / 10, size=2) + # II Quadrant + # self.goal = np.random.uniform(low=[-self.n_links / 10, 0], high=[0, self.n_links / 10], size=2) + # II + III Quadrant + # self.goal = np.random.uniform(low=-self.n_links / 10, high=[0, self.n_links / 10], size=2) + # I + II Quadrant + # self.goal = np.random.uniform(low=[-self.n_links / 10, 0], high=self.n_links, size=2) + if np.linalg.norm(self.goal) < self.n_links / 10: + break + + qpos[-2:] = self.goal + qvel = ( + # self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv) + + self.init_qvel.copy() + ) + qvel[-2:] = 0 + self.set_state(qpos, qvel) + self._steps = 0 + + return self._get_obs() + + def _get_obs(self): + theta = self.sim.data.qpos.flat[:self.n_links] + target = self.get_body_com("target") + return np.concatenate([ + np.cos(theta), + np.sin(theta), + target[:2], # x-y of goal position + self.sim.data.qvel.flat[:self.n_links], # angular velocity + self.get_body_com("fingertip") - target, # goal distance + ]) diff --git a/alr_envs/alr/mujoco/table_tennis/mp_wrapper.py b/alr_envs/alr/mujoco/table_tennis/mp_wrapper.py index 408124a..40e4252 100644 --- a/alr_envs/alr/mujoco/table_tennis/mp_wrapper.py +++ b/alr_envs/alr/mujoco/table_tennis/mp_wrapper.py @@ -2,7 +2,7 @@ from typing import Tuple, Union import numpy as np -from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper +from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper class MPWrapper(RawInterfaceWrapper): diff --git a/alr_envs/alr/mujoco/walker_2d_jump/mp_wrapper.py b/alr_envs/alr/mujoco/walker_2d_jump/mp_wrapper.py index 0c2dba5..5e9d0eb 100644 --- a/alr_envs/alr/mujoco/walker_2d_jump/mp_wrapper.py +++ b/alr_envs/alr/mujoco/walker_2d_jump/mp_wrapper.py @@ -2,7 +2,7 @@ from typing import Tuple, Union import numpy as np -from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper +from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper class MPWrapper(RawInterfaceWrapper): diff --git a/alr_envs/alr/mujoco/walker_2d_jump/new_mp_wrapper.py b/alr_envs/alr/mujoco/walker_2d_jump/new_mp_wrapper.py index 96b0739..b2bfde9 100644 --- a/alr_envs/alr/mujoco/walker_2d_jump/new_mp_wrapper.py +++ b/alr_envs/alr/mujoco/walker_2d_jump/new_mp_wrapper.py @@ -2,7 +2,7 @@ from typing import Tuple, Union import numpy as np -from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper +from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper class MPWrapper(RawInterfaceWrapper): diff --git a/alr_envs/mp/__init__.py b/alr_envs/black_box/__init__.py similarity index 100% rename from 
alr_envs/mp/__init__.py rename to alr_envs/black_box/__init__.py diff --git a/alr_envs/mp/black_box_wrapper.py b/alr_envs/black_box/black_box_wrapper.py similarity index 94% rename from alr_envs/mp/black_box_wrapper.py rename to alr_envs/black_box/black_box_wrapper.py index f1ba41f..0c10ef8 100644 --- a/alr_envs/mp/black_box_wrapper.py +++ b/alr_envs/black_box/black_box_wrapper.py @@ -6,8 +6,8 @@ import numpy as np from gym import spaces from mp_pytorch.mp.mp_interfaces import MPInterface -from alr_envs.mp.controllers.base_controller import BaseController -from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper +from alr_envs.black_box.controller.base_controller import BaseController +from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper from alr_envs.utils.utils import get_numpy @@ -15,10 +15,14 @@ class BlackBoxWrapper(gym.ObservationWrapper): def __init__(self, env: RawInterfaceWrapper, - trajectory_generator: MPInterface, tracking_controller: BaseController, - duration: float, verbose: int = 1, learn_sub_trajectories: bool = False, + trajectory_generator: MPInterface, + tracking_controller: BaseController, + duration: float, + verbose: int = 1, + learn_sub_trajectories: bool = False, replanning_schedule: Union[None, callable] = None, - reward_aggregation: callable = np.sum): + reward_aggregation: callable = np.sum + ): """ gym.Wrapper for leveraging a black box approach with a trajectory generator. diff --git a/alr_envs/mp/controllers/__init__.py b/alr_envs/black_box/controller/__init__.py similarity index 100% rename from alr_envs/mp/controllers/__init__.py rename to alr_envs/black_box/controller/__init__.py diff --git a/alr_envs/mp/controllers/base_controller.py b/alr_envs/black_box/controller/base_controller.py similarity index 100% rename from alr_envs/mp/controllers/base_controller.py rename to alr_envs/black_box/controller/base_controller.py diff --git a/alr_envs/mp/controllers/controller_factory.py b/alr_envs/black_box/controller/controller_factory.py similarity index 60% rename from alr_envs/mp/controllers/controller_factory.py rename to alr_envs/black_box/controller/controller_factory.py index 1c1cfec..6ef7960 100644 --- a/alr_envs/mp/controllers/controller_factory.py +++ b/alr_envs/black_box/controller/controller_factory.py @@ -1,7 +1,7 @@ -from alr_envs.mp.controllers.meta_world_controller import MetaWorldController -from alr_envs.mp.controllers.pd_controller import PDController -from alr_envs.mp.controllers.vel_controller import VelController -from alr_envs.mp.controllers.pos_controller import PosController +from alr_envs.black_box.controller.meta_world_controller import MetaWorldController +from alr_envs.black_box.controller.pd_controller import PDController +from alr_envs.black_box.controller.vel_controller import VelController +from alr_envs.black_box.controller.pos_controller import PosController ALL_TYPES = ["motor", "velocity", "position", "metaworld"] diff --git a/alr_envs/mp/controllers/meta_world_controller.py b/alr_envs/black_box/controller/meta_world_controller.py similarity index 92% rename from alr_envs/mp/controllers/meta_world_controller.py rename to alr_envs/black_box/controller/meta_world_controller.py index 5747f9e..296ea3a 100644 --- a/alr_envs/mp/controllers/meta_world_controller.py +++ b/alr_envs/black_box/controller/meta_world_controller.py @@ -1,6 +1,6 @@ import numpy as np -from alr_envs.mp.controllers.base_controller import BaseController +from alr_envs.black_box.controller.base_controller import BaseController class 
MetaWorldController(BaseController): diff --git a/alr_envs/mp/controllers/pd_controller.py b/alr_envs/black_box/controller/pd_controller.py similarity index 93% rename from alr_envs/mp/controllers/pd_controller.py rename to alr_envs/black_box/controller/pd_controller.py index 140aeee..ab21444 100644 --- a/alr_envs/mp/controllers/pd_controller.py +++ b/alr_envs/black_box/controller/pd_controller.py @@ -1,6 +1,6 @@ from typing import Union, Tuple -from alr_envs.mp.controllers.base_controller import BaseController +from alr_envs.black_box.controller.base_controller import BaseController class PDController(BaseController): diff --git a/alr_envs/mp/controllers/pos_controller.py b/alr_envs/black_box/controller/pos_controller.py similarity index 77% rename from alr_envs/mp/controllers/pos_controller.py rename to alr_envs/black_box/controller/pos_controller.py index 5570307..3f3526a 100644 --- a/alr_envs/mp/controllers/pos_controller.py +++ b/alr_envs/black_box/controller/pos_controller.py @@ -1,4 +1,4 @@ -from alr_envs.mp.controllers.base_controller import BaseController +from alr_envs.black_box.controller.base_controller import BaseController class PosController(BaseController): diff --git a/alr_envs/mp/controllers/vel_controller.py b/alr_envs/black_box/controller/vel_controller.py similarity index 77% rename from alr_envs/mp/controllers/vel_controller.py rename to alr_envs/black_box/controller/vel_controller.py index 67bab2a..2134207 100644 --- a/alr_envs/mp/controllers/vel_controller.py +++ b/alr_envs/black_box/controller/vel_controller.py @@ -1,4 +1,4 @@ -from alr_envs.mp.controllers.base_controller import BaseController +from alr_envs.black_box.controller.base_controller import BaseController class VelController(BaseController): diff --git a/alr_envs/black_box/factory/__init__.py b/alr_envs/black_box/factory/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/alr_envs/mp/basis_generator_factory.py b/alr_envs/black_box/factory/basis_generator_factory.py similarity index 100% rename from alr_envs/mp/basis_generator_factory.py rename to alr_envs/black_box/factory/basis_generator_factory.py diff --git a/alr_envs/mp/phase_generator_factory.py b/alr_envs/black_box/factory/phase_generator_factory.py similarity index 93% rename from alr_envs/mp/phase_generator_factory.py rename to alr_envs/black_box/factory/phase_generator_factory.py index 45cfdc1..ca0dd84 100644 --- a/alr_envs/mp/phase_generator_factory.py +++ b/alr_envs/black_box/factory/phase_generator_factory.py @@ -17,4 +17,4 @@ def get_phase_generator(phase_generator_type, **kwargs): return SmoothPhaseGenerator(**kwargs) else: raise ValueError(f"Specified phase generator type {phase_generator_type} not supported, " - f"please choose one of {ALL_TYPES}.") \ No newline at end of file + f"please choose one of {ALL_TYPES}.") diff --git a/alr_envs/mp/mp_factory.py b/alr_envs/black_box/factory/trajectory_generator_factory.py similarity index 91% rename from alr_envs/mp/mp_factory.py rename to alr_envs/black_box/factory/trajectory_generator_factory.py index d2c5460..784513e 100644 --- a/alr_envs/mp/mp_factory.py +++ b/alr_envs/black_box/factory/trajectory_generator_factory.py @@ -9,7 +9,7 @@ ALL_TYPES = ["promp", "dmp", "idmp"] def get_trajectory_generator( trajectory_generator_type: str, action_dim: int, basis_generator: BasisGenerator, **kwargs - ): +): trajectory_generator_type = trajectory_generator_type.lower() if trajectory_generator_type == "promp": return ProMP(basis_generator, action_dim, **kwargs) @@ -19,4 +19,4 @@ def 
get_trajectory_generator(
         return IDMP(basis_generator, action_dim, **kwargs)
     else:
         raise ValueError(f"Specified movement primitive type {trajectory_generator_type} not supported, "
-                         f"please choose one of {ALL_TYPES}.")
\ No newline at end of file
+                         f"please choose one of {ALL_TYPES}.")
diff --git a/alr_envs/mp/raw_interface_wrapper.py b/alr_envs/black_box/raw_interface_wrapper.py
similarity index 100%
rename from alr_envs/mp/raw_interface_wrapper.py
rename to alr_envs/black_box/raw_interface_wrapper.py
diff --git a/alr_envs/dmc/manipulation/reach_site/mp_wrapper.py b/alr_envs/dmc/manipulation/reach_site/mp_wrapper.py
index 6d5029e..d918edc 100644
--- a/alr_envs/dmc/manipulation/reach_site/mp_wrapper.py
+++ b/alr_envs/dmc/manipulation/reach_site/mp_wrapper.py
@@ -2,7 +2,7 @@ from typing import Tuple, Union

 import numpy as np

-from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
+from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper


 class MPWrapper(RawInterfaceWrapper):
diff --git a/alr_envs/dmc/suite/ball_in_cup/mp_wrapper.py b/alr_envs/dmc/suite/ball_in_cup/mp_wrapper.py
index 9687bed..b3f882e 100644
--- a/alr_envs/dmc/suite/ball_in_cup/mp_wrapper.py
+++ b/alr_envs/dmc/suite/ball_in_cup/mp_wrapper.py
@@ -2,7 +2,7 @@ from typing import Tuple, Union

 import numpy as np

-from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
+from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper


 class MPWrapper(RawInterfaceWrapper):
diff --git a/alr_envs/dmc/suite/cartpole/mp_wrapper.py b/alr_envs/dmc/suite/cartpole/mp_wrapper.py
index 3f16d24..6cc0687 100644
--- a/alr_envs/dmc/suite/cartpole/mp_wrapper.py
+++ b/alr_envs/dmc/suite/cartpole/mp_wrapper.py
@@ -2,7 +2,7 @@ from typing import Tuple, Union

 import numpy as np

-from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
+from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper


 class MPWrapper(RawInterfaceWrapper):
diff --git a/alr_envs/dmc/suite/reacher/mp_wrapper.py b/alr_envs/dmc/suite/reacher/mp_wrapper.py
index ac857c1..82e4da8 100644
--- a/alr_envs/dmc/suite/reacher/mp_wrapper.py
+++ b/alr_envs/dmc/suite/reacher/mp_wrapper.py
@@ -2,7 +2,7 @@ from typing import Tuple, Union

 import numpy as np

-from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
+from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper


 class MPWrapper(RawInterfaceWrapper):
diff --git a/alr_envs/meta/goal_change_mp_wrapper.py b/alr_envs/meta/goal_change_mp_wrapper.py
index 17495da..e628a0c 100644
--- a/alr_envs/meta/goal_change_mp_wrapper.py
+++ b/alr_envs/meta/goal_change_mp_wrapper.py
@@ -2,7 +2,7 @@ from typing import Tuple, Union

 import numpy as np

-from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
+from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper


 class MPWrapper(RawInterfaceWrapper):
diff --git a/alr_envs/meta/goal_endeffector_change_mp_wrapper.py b/alr_envs/meta/goal_endeffector_change_mp_wrapper.py
index 3a6ad1c..1a128e7 100644
--- a/alr_envs/meta/goal_endeffector_change_mp_wrapper.py
+++ b/alr_envs/meta/goal_endeffector_change_mp_wrapper.py
@@ -2,7 +2,7 @@ from typing import Tuple, Union

 import numpy as np

-from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
+from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper


 class MPWrapper(RawInterfaceWrapper):
diff --git a/alr_envs/meta/goal_object_change_mp_wrapper.py b/alr_envs/meta/goal_object_change_mp_wrapper.py
index 97c64b8..1a6f57e 100644
--- a/alr_envs/meta/goal_object_change_mp_wrapper.py
+++ b/alr_envs/meta/goal_object_change_mp_wrapper.py
@@ -2,7 +2,7 @@ from typing import Tuple, Union

 import numpy as np

-from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
+from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper


 class MPWrapper(RawInterfaceWrapper):
diff --git a/alr_envs/meta/object_change_mp_wrapper.py b/alr_envs/meta/object_change_mp_wrapper.py
index f832c9f..07e88dc 100644
--- a/alr_envs/meta/object_change_mp_wrapper.py
+++ b/alr_envs/meta/object_change_mp_wrapper.py
@@ -2,7 +2,7 @@ from typing import Tuple, Union

 import numpy as np

-from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
+from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper


 class MPWrapper(RawInterfaceWrapper):
diff --git a/alr_envs/open_ai/classic_control/continuous_mountain_car/mp_wrapper.py b/alr_envs/open_ai/classic_control/continuous_mountain_car/mp_wrapper.py
index 189563c..db9f1f2 100644
--- a/alr_envs/open_ai/classic_control/continuous_mountain_car/mp_wrapper.py
+++ b/alr_envs/open_ai/classic_control/continuous_mountain_car/mp_wrapper.py
@@ -2,7 +2,7 @@ from typing import Union

 import numpy as np

-from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
+from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper


 class MPWrapper(RawInterfaceWrapper):
diff --git a/alr_envs/open_ai/mujoco/reacher_v2/mp_wrapper.py b/alr_envs/open_ai/mujoco/reacher_v2/mp_wrapper.py
index 9d627b6..1e02f10 100644
--- a/alr_envs/open_ai/mujoco/reacher_v2/mp_wrapper.py
+++ b/alr_envs/open_ai/mujoco/reacher_v2/mp_wrapper.py
@@ -2,7 +2,7 @@ from typing import Union

 import numpy as np

-from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
+from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper


 class MPWrapper(RawInterfaceWrapper):
diff --git a/alr_envs/open_ai/robotics/fetch/mp_wrapper.py b/alr_envs/open_ai/robotics/fetch/mp_wrapper.py
index 7a7bed6..331aa40 100644
--- a/alr_envs/open_ai/robotics/fetch/mp_wrapper.py
+++ b/alr_envs/open_ai/robotics/fetch/mp_wrapper.py
@@ -2,7 +2,7 @@ from typing import Union

 import numpy as np

-from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
+from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper


 class MPWrapper(RawInterfaceWrapper):
diff --git a/alr_envs/utils/make_env_helpers.py b/alr_envs/utils/make_env_helpers.py
index b5587a7..3a8eec5 100644
--- a/alr_envs/utils/make_env_helpers.py
+++ b/alr_envs/utils/make_env_helpers.py
@@ -7,12 +7,12 @@ import numpy as np
 from gym.envs.registration import EnvSpec, registry
 from gym.wrappers import TimeAwareObservation

-from alr_envs.mp.basis_generator_factory import get_basis_generator
-from alr_envs.mp.black_box_wrapper import BlackBoxWrapper
-from alr_envs.mp.controllers.controller_factory import get_controller
-from alr_envs.mp.mp_factory import get_trajectory_generator
-from alr_envs.mp.phase_generator_factory import get_phase_generator
-from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
+from alr_envs.black_box.factory.basis_generator_factory import get_basis_generator
+from alr_envs.black_box.black_box_wrapper import BlackBoxWrapper
+from alr_envs.black_box.controller.controller_factory import get_controller
+from alr_envs.black_box.factory.trajectory_generator_factory import get_trajectory_generator
+from alr_envs.black_box.factory.phase_generator_factory import get_phase_generator
+from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper
 from alr_envs.utils.utils import nested_update

@@ -46,6 +46,7 @@ def make(env_id, seed, **kwargs):
     spec = registry.get(env_id)
     # This access is required to allow for nested dict updates
     all_kwargs = deepcopy(spec._kwargs)
+    # TODO append wrapper here
     nested_update(all_kwargs, **kwargs)
     return _make(env_id, seed, **all_kwargs)

@@ -224,8 +225,8 @@ def make_bb_env_helper(**kwargs):
     seed = kwargs.pop("seed", None)
     wrappers = kwargs.pop("wrappers")
-    traj_gen_kwargs = kwargs.pop("traj_gen_kwargs", {})
     black_box_kwargs = kwargs.pop('black_box_kwargs', {})
+    traj_gen_kwargs = kwargs.pop("traj_gen_kwargs", {})
     contr_kwargs = kwargs.pop("controller_kwargs", {})
     phase_kwargs = kwargs.pop("phase_generator_kwargs", {})
     basis_kwargs = kwargs.pop("basis_generator_kwargs", {})