From 1b061b2a378f45fa0ff99ec4d72bae4ecfdef256 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Thu, 20 Jul 2023 11:45:32 +0200 Subject: [PATCH] ported mp_config for mujoco/box_pushing --- fancy_gym/envs/mujoco/box_pushing/__init__.py | 2 +- .../envs/mujoco/box_pushing/mp_wrapper.py | 43 +++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/fancy_gym/envs/mujoco/box_pushing/__init__.py b/fancy_gym/envs/mujoco/box_pushing/__init__.py index c5e6d2f..d683024 100644 --- a/fancy_gym/envs/mujoco/box_pushing/__init__.py +++ b/fancy_gym/envs/mujoco/box_pushing/__init__.py @@ -1 +1 @@ -from .mp_wrapper import MPWrapper +from .mp_wrapper import MPWrapper, ReplanMPWrapper diff --git a/fancy_gym/envs/mujoco/box_pushing/mp_wrapper.py b/fancy_gym/envs/mujoco/box_pushing/mp_wrapper.py index 09b2d65..03121f9 100644 --- a/fancy_gym/envs/mujoco/box_pushing/mp_wrapper.py +++ b/fancy_gym/envs/mujoco/box_pushing/mp_wrapper.py @@ -6,6 +6,19 @@ from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper class MPWrapper(RawInterfaceWrapper): + mp_config = { + 'ProMP': { + 'controller_kwargs': { + 'p_gains': 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.]), + 'd_gains': 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.]), + }, + 'basis_generator_kwargs': { + 'basis_bandwidth_factor': 2 # 3.5, 4 to try + } + }, + 'DMP': {}, + 'ProDMP': {}, + } # Random x goal + random init pos @property @@ -27,3 +40,33 @@ class MPWrapper(RawInterfaceWrapper): @property def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: return self.data.qvel[:7].copy() + + +class ReplanMPWrapper(MPWrapper): + mp_config = { + 'ProMP': {}, + 'DMP': {}, + 'ProDMP': { + 'controller_kwargs': { + 'p_gains': 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.]), + 'd_gains': 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.]), + }, + 'trajectory_generator_kwargs': { + 'weights_scale': 0.3, + 'goal_scale': 0.3, + 'auto_scale_basis': True, + 'goal_offset': 1.0, + 'disable_goal': True, + }, + 'basis_generator_kwargs': { + 'num_basis': 5, + 'basis_bandwidth_factor': 3, + 'alpha_phase': 3, + }, + 'black_box_kwargs': { + 'max_planning_times': 4, + 'replanning_schedule': lambda pos, vel, obs, action, t: t % 25 == 0, + 'condition_on_desired': True, + } + } + }