diff --git a/fancy_gym/envs/mujoco/beerpong/mp_wrapper.py b/fancy_gym/envs/mujoco/beerpong/mp_wrapper.py index 17a11e1..452ee05 100644 --- a/fancy_gym/envs/mujoco/beerpong/mp_wrapper.py +++ b/fancy_gym/envs/mujoco/beerpong/mp_wrapper.py @@ -6,6 +6,23 @@ from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper class MPWrapper(RawInterfaceWrapper): + mp_config = { + 'ProMP': { + 'phase_generator_kwargs': { + 'learn_tau': True + }, + 'controller_kwargs': { + 'p_gains': np.array([1.5, 5, 2.55, 3, 2., 2, 1.25]), + 'd_gains': np.array([0.02333333, 0.1, 0.0625, 0.08, 0.03, 0.03, 0.0125]), + }, + 'basis_generator_kwargs': { + 'num_basis': 2, + 'num_basis_zero_start': 2, + }, + }, + 'DMP': {}, + 'ProDMP': {}, + } @property def context_mask(self) -> np.ndarray: @@ -39,3 +56,23 @@ class MPWrapper(RawInterfaceWrapper): xyz[-1] = 0.840 self.model.body_pos[self.cup_table_id] = xyz return self.get_observation_from_step(self.get_obs()) + + +class MPWrapper_FixedRelease(MPWrapper): + mp_config = { + 'ProMP': { + 'phase_generator_kwargs': { + 'tau': 0.62, + }, + 'controller_kwargs': { + 'p_gains': np.array([1.5, 5, 2.55, 3, 2., 2, 1.25]), + 'd_gains': np.array([0.02333333, 0.1, 0.0625, 0.08, 0.03, 0.03, 0.0125]), + }, + 'basis_generator_kwargs': { + 'num_basis': 2, + 'num_basis_zero_start': 2, + }, + }, + 'DMP': {}, + 'ProDMP': {}, + }