ported mp_config for mujoco/table_tennis
This commit is contained in:
		
							parent
							
								
									64e6ac5323
								
							
						
					
					
						commit
						9ba3fa9dbc
					
				@ -7,6 +7,53 @@ from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import jnt_pos_low, j
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TT_MPWrapper(RawInterfaceWrapper):
 | 
			
		||||
    mp_config = {
 | 
			
		||||
        'ProMP': {
 | 
			
		||||
            'phase_generator_kwargs': {
 | 
			
		||||
                'learn_tau': False,
 | 
			
		||||
                'learn_delay': False,
 | 
			
		||||
                'tau_bound': [0.8, 1.5],
 | 
			
		||||
                'delay_bound': [0.05, 0.15],
 | 
			
		||||
            },
 | 
			
		||||
            'controller_kwargs': {
 | 
			
		||||
                'p_gains': 0.5 * np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0]),
 | 
			
		||||
                'd_gains': 0.5 * np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1]),
 | 
			
		||||
            },
 | 
			
		||||
            'basis_generator_kwargs': {
 | 
			
		||||
                'num_basis': 3,
 | 
			
		||||
                'num_basis_zero_start': 1,
 | 
			
		||||
                'num_basis_zero_goal': 1,
 | 
			
		||||
            },
 | 
			
		||||
            'black_box_kwargs': {
 | 
			
		||||
                'verbose': 2,
 | 
			
		||||
            },
 | 
			
		||||
        },
 | 
			
		||||
        'DMP': {},
 | 
			
		||||
        'ProDMP': {
 | 
			
		||||
            'phase_generator_kwargs': {
 | 
			
		||||
                'learn_tau': True,
 | 
			
		||||
                'learn_delay': True,
 | 
			
		||||
                'tau_bound': [0.8, 1.5],
 | 
			
		||||
                'delay_bound': [0.05, 0.15],
 | 
			
		||||
                'alpha_phase': 3,
 | 
			
		||||
            },
 | 
			
		||||
            'controller_kwargs': {
 | 
			
		||||
                'p_gains': 0.5 * np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0]),
 | 
			
		||||
                'd_gains': 0.5 * np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1]),
 | 
			
		||||
            },
 | 
			
		||||
            'basis_generator_kwargs': {
 | 
			
		||||
                'num_basis': 3,
 | 
			
		||||
                'alpha': 25,
 | 
			
		||||
                'basis_bandwidth_factor': 3,
 | 
			
		||||
            },
 | 
			
		||||
            'trajectory_generator_kwargs': {
 | 
			
		||||
                'weights_scale': 0.7,
 | 
			
		||||
                'auto_scale_basis': True,
 | 
			
		||||
                'relative_goal': True,
 | 
			
		||||
                'disable_goal': True,
 | 
			
		||||
            },
 | 
			
		||||
        },
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    # Random x goal + random init pos
 | 
			
		||||
    @property
 | 
			
		||||
@ -39,7 +86,58 @@ class TT_MPWrapper(RawInterfaceWrapper):
 | 
			
		||||
                              return_contextual_obs: bool) -> Tuple[np.ndarray, float, bool, dict]:
 | 
			
		||||
        return self.get_invalid_traj_step_return(action, pos_traj, return_contextual_obs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TT_MPWrapper_Replan(TT_MPWrapper):
    """Variant of ``TT_MPWrapper`` whose ProDMP configuration enables replanning.

    Only ``mp_config`` is overridden. The ``ProMP`` and ``DMP`` entries are
    left empty; the ``ProDMP`` entry carries phase, controller, basis and
    trajectory generator settings plus the black-box replanning options.
    """

    mp_config = {
        # No ProMP / DMP specific settings are defined for this variant.
        'ProMP': {},
        'DMP': {},
        'ProDMP': {
            # Both tau and delay are learned, each with a [low, high] bound.
            'phase_generator_kwargs': {
                'learn_tau': True, 'learn_delay': True,
                'tau_bound': [0.8, 1.5], 'delay_bound': [0.05, 0.15],
                'alpha_phase': 3,
            },
            # Seven gains per array, all scaled by one half.
            'controller_kwargs': {
                'p_gains': np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0]) * 0.5,
                'd_gains': np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1]) * 0.5,
            },
            'basis_generator_kwargs': {
                'num_basis': 2, 'alpha': 25, 'basis_bandwidth_factor': 3,
            },
            'trajectory_generator_kwargs': {
                'auto_scale_basis': True, 'goal_offset': 1.0,
            },
            'black_box_kwargs': {
                'max_planning_times': 3,
                # The schedule is true whenever t is a multiple of 50; the
                # other arguments are accepted but not used.
                'replanning_schedule': lambda pos, vel, obs, action, t: t % 50 == 0,
            },
        },
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TTVelObs_MPWrapper(TT_MPWrapper):
    """``TT_MPWrapper`` variant with a mask that also covers ball velocity.

    ``mp_config`` is not redefined here, so it is inherited from
    ``TT_MPWrapper``.
    """

    @property
    def context_mask(self):
        """Return a 22-element boolean mask assembled from per-group segments.

        True entries are the ball x/y position, the ball velocity and the
        target landing position; joint states and the ball z position are
        False. NOTE(review): presumably True marks observation entries used
        as context -- confirm against ``RawInterfaceWrapper``.
        """
        segments = (
            np.zeros(7, dtype=bool),  # joints position
            np.zeros(7, dtype=bool),  # joints velocity
            np.ones(2, dtype=bool),   # position ball x, y
            np.zeros(1, dtype=bool),  # position ball z
            np.ones(3, dtype=bool),   # velocity ball x, y, z
            np.ones(2, dtype=bool),   # target landing position
            # A trailing time entry ([True] * 1) is currently disabled.
        )
        return np.concatenate(segments)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TTVelObs_MPWrapper_Replan(TT_MPWrapper_Replan):
 | 
			
		||||
    # Will inherit mp_config from TT_MPWrapper_Replan
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def context_mask(self):
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user