From 819fca1b2e191afdf641a932633ed9b7d9a2f0ba Mon Sep 17 00:00:00 2001 From: Onur Date: Thu, 7 Jul 2022 09:39:20 +0200 Subject: [PATCH] new reacher mp wrapper for Philipp --- alr_envs/alr/__init__.py | 40 ++++++++++++------- alr_envs/alr/mujoco/reacher/__init__.py | 3 +- alr_envs/alr/mujoco/reacher/alr_reacher.py | 2 +- alr_envs/alr/mujoco/reacher/new_mp_wrapper.py | 4 +- 4 files changed, 31 insertions(+), 18 deletions(-) diff --git a/alr_envs/alr/__init__.py b/alr_envs/alr/__init__.py index 09f533a..b49776a 100644 --- a/alr_envs/alr/__init__.py +++ b/alr_envs/alr/__init__.py @@ -668,21 +668,33 @@ for _v in _versions: _env_id = f'{_name[0]}ProMP-{_name[1]}' register( id=_env_id, - entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper', + entry_point='alr_envs.utils.make_env_helpers:make_mp_env_helper', kwargs={ "name": f"alr_envs:{_v}", - "wrappers": [mujoco.reacher.MPWrapper], - "mp_kwargs": { - "num_dof": 5 if "long" not in _v.lower() else 7, - "num_basis": 2, - "duration": 4, - "policy_type": "motor", - "weights_scale": 5, - "zero_start": True, - "policy_kwargs": { - "p_gains": 1, - "d_gains": 0.1 - } + "wrappers": [mujoco.reacher.NewMPWrapper], + "ep_wrapper_kwargs": { + "weight_scale": 1 + }, + "movement_primitives_kwargs": { + 'movement_primitives_type': 'promp', + 'action_dim': 5 if "long" not in _v.lower() else 7 + }, + "phase_generator_kwargs": { + 'phase_generator_type': 'linear', + 'delay': 0, + 'tau': 4, # initial value + 'learn_tau': False, + 'learn_delay': False + }, + "controller_kwargs": { + 'controller_type': 'motor', + "p_gains": 1, + "d_gains": 0.1 + }, + "basis_generator_kwargs": { + 'basis_generator_type': 'zero_rbf', + 'num_basis': 2, + 'num_basis_zero_start': 1 } } ) @@ -1275,4 +1287,4 @@ for _v in _versions: } } ) - ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) \ No newline at end of file + ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) diff --git a/alr_envs/alr/mujoco/reacher/__init__.py b/alr_envs/alr/mujoco/reacher/__init__.py index 989b5a9..5d15867 100644 --- a/alr_envs/alr/mujoco/reacher/__init__.py +++ b/alr_envs/alr/mujoco/reacher/__init__.py @@ -1 +1,2 @@ -from .mp_wrapper import MPWrapper \ No newline at end of file +from .mp_wrapper import MPWrapper +from .new_mp_wrapper import NewMPWrapper diff --git a/alr_envs/alr/mujoco/reacher/alr_reacher.py b/alr_envs/alr/mujoco/reacher/alr_reacher.py index c12352a..0699c44 100644 --- a/alr_envs/alr/mujoco/reacher/alr_reacher.py +++ b/alr_envs/alr/mujoco/reacher/alr_reacher.py @@ -149,4 +149,4 @@ if __name__ == '__main__': if d: env.reset() - env.close() \ No newline at end of file + env.close() diff --git a/alr_envs/alr/mujoco/reacher/new_mp_wrapper.py b/alr_envs/alr/mujoco/reacher/new_mp_wrapper.py index 02dc1d8..bf59380 100644 --- a/alr_envs/alr/mujoco/reacher/new_mp_wrapper.py +++ b/alr_envs/alr/mujoco/reacher/new_mp_wrapper.py @@ -3,7 +3,7 @@ from typing import Union, Tuple import numpy as np -class MPWrapper(EpisodicWrapper): +class NewMPWrapper(EpisodicWrapper): @property def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: @@ -21,4 +21,4 @@ class MPWrapper(EpisodicWrapper): [False] * 3, # goal distance # self.get_body_com("target"), # only return target to make problem harder [False], # step - ]) \ No newline at end of file + ])