diff --git a/fancy_gym/envs/mujoco/table_tennis/mp_wrapper.py b/fancy_gym/envs/mujoco/table_tennis/mp_wrapper.py
index e33ed6c..3e5a464 100644
--- a/fancy_gym/envs/mujoco/table_tennis/mp_wrapper.py
+++ b/fancy_gym/envs/mujoco/table_tennis/mp_wrapper.py
@@ -7,6 +7,53 @@ from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import jnt_pos_low, j
 
 
 class TT_MPWrapper(RawInterfaceWrapper):
+    mp_config = {
+        'ProMP': {
+            'phase_generator_kwargs': {
+                'learn_tau': False,
+                'learn_delay': False,
+                'tau_bound': [0.8, 1.5],
+                'delay_bound': [0.05, 0.15],
+            },
+            'controller_kwargs': {
+                'p_gains': 0.5 * np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0]),
+                'd_gains': 0.5 * np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1]),
+            },
+            'basis_generator_kwargs': {
+                'num_basis': 3,
+                'num_basis_zero_start': 1,
+                'num_basis_zero_goal': 1,
+            },
+            'black_box_kwargs': {
+                'verbose': 2,
+            },
+        },
+        'DMP': {},
+        'ProDMP': {
+            'phase_generator_kwargs': {
+                'learn_tau': True,
+                'learn_delay': True,
+                'tau_bound': [0.8, 1.5],
+                'delay_bound': [0.05, 0.15],
+                'alpha_phase': 3,
+            },
+            'controller_kwargs': {
+                'p_gains': 0.5 * np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0]),
+                'd_gains': 0.5 * np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1]),
+            },
+            'basis_generator_kwargs': {
+                'num_basis': 3,
+                'alpha': 25,
+                'basis_bandwidth_factor': 3,
+            },
+            'trajectory_generator_kwargs': {
+                'weights_scale': 0.7,
+                'auto_scale_basis': True,
+                'relative_goal': True,
+                'disable_goal': True,
+            },
+        },
+    }
 
     # Random x goal + random init pos
     @property
@@ -16,7 +63,7 @@ class TT_MPWrapper(RawInterfaceWrapper):
             [False] * 7, # joints velocity
             [True] * 2, # position ball x, y
             [False] * 1, # position ball z
-            #[True] * 3, # velocity ball x, y, z
+            # [True] * 3, # velocity ball x, y, z
             [True] * 2, # target landing position
             # [True] * 1, # time
         ])
@@ -39,7 +86,42 @@ class TT_MPWrapper(RawInterfaceWrapper):
                              return_contextual_obs: bool) -> Tuple[np.ndarray, float, bool, dict]:
         return self.get_invalid_traj_step_return(action, pos_traj, return_contextual_obs)
 
+
+class TT_MPWrapper_Replan(TT_MPWrapper):
+    mp_config = {
+        'ProMP': {},
+        'DMP': {},
+        'ProDMP': {
+            'phase_generator_kwargs': {
+                'learn_tau': True,
+                'learn_delay': True,
+                'tau_bound': [0.8, 1.5],
+                'delay_bound': [0.05, 0.15],
+                'alpha_phase': 3,
+            },
+            'controller_kwargs': {
+                'p_gains': 0.5 * np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0]),
+                'd_gains': 0.5 * np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1]),
+            },
+            'basis_generator_kwargs': {
+                'num_basis': 2,
+                'alpha': 25,
+                'basis_bandwidth_factor': 3,
+            },
+            'trajectory_generator_kwargs': {
+                'auto_scale_basis': True,
+                'goal_offset': 1.0,
+            },
+            'black_box_kwargs': {
+                'max_planning_times': 3,
+                'replanning_schedule': lambda pos, vel, obs, action, t: t % 50 == 0,
+            },
+        },
+    }
+
+
 class TTVelObs_MPWrapper(TT_MPWrapper):
+    # Will inherit mp_config from TT_MPWrapper
 
     @property
     def context_mask(self):
@@ -51,4 +133,20 @@ class TTVelObs_MPWrapper(TT_MPWrapper):
             [True] * 3, # velocity ball x, y, z
             [True] * 2, # target landing position
             # [True] * 1, # time
-        ])
\ No newline at end of file
+        ])
+
+
+class TTVelObs_MPWrapper_Replan(TT_MPWrapper_Replan):
+    # Will inherit mp_config from TT_MPWrapper_Replan
+
+    @property
+    def context_mask(self):
+        return np.hstack([
+            [False] * 7, # joints position
+            [False] * 7, # joints velocity
+            [True] * 2, # position ball x, y
+            [False] * 1, # position ball z
+            [True] * 3, # velocity ball x, y, z
+            [True] * 2, # target landing position
+            # [True] * 1, # time
+        ])
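
For context, the per-wrapper mp_config dictionaries added in this patch are picked up when the movement-primitive (ProMP/DMP/ProDMP) variants of the table tennis environment are constructed, so the wrapper itself carries the default phase, controller, basis, and trajectory generator settings. A minimal usage sketch is below; the environment ID 'fancy_ProDMP/TableTennis4D-v0' and the gymnasium-style registration are assumptions about the surrounding fancy_gym setup and may differ between versions.

    # Sketch only: the env ID and API style below are assumptions, not confirmed by this patch.
    import gymnasium as gym
    import fancy_gym  # importing fancy_gym registers its environments with gymnasium

    # Hypothetical ID for illustration; pick the ProDMP variant so the ProDMP
    # block of TT_MPWrapper.mp_config above is the one that applies.
    env = gym.make('fancy_ProDMP/TableTennis4D-v0')

    obs, info = env.reset(seed=0)
    # One step of a black-box MP environment rolls out a whole trajectory
    # generated from the sampled MP parameters (tau/delay are learned here,
    # bounded by tau_bound/delay_bound), tracked by the PD controller gains.
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    env.close()

The replanning wrappers (TT_MPWrapper_Replan and its VelObs variant) differ mainly in their ProDMP black_box_kwargs: replanning_schedule triggers a new sub-trajectory every 50 control steps, up to max_planning_times = 3 plans per episode.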