diff --git a/fancy_gym/envs/__init__.py b/fancy_gym/envs/__init__.py
index 3b26345..e6c0c80 100644
--- a/fancy_gym/envs/__init__.py
+++ b/fancy_gym/envs/__init__.py
@@ -572,6 +572,50 @@ for _v in _versions:
     )
     ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
 
+# Register one ProDMP black-box environment per version string in `_versions`.
+for _v in _versions:
+    # '<base>-<version>' becomes '<base>ProDMP-<version>'.
+    _name = _v.split("-")
+    _env_id = f'{_name[0]}ProDMP-{_name[1]}'
+    # Deep-copy the defaults so the edits below never touch DEFAULT_BB_DICT_ProDMP.
+    kwargs_dict_tt_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
+    # 'TableTennisWind-v0' uses TTVelObs_MPWrapper; every other version uses TT_MPWrapper.
+    kwargs_dict_tt_prodmp['wrappers'].append(
+        mujoco.table_tennis.TTVelObs_MPWrapper if _v == 'TableTennisWind-v0'
+        else mujoco.table_tennis.TT_MPWrapper
+    )
+    kwargs_dict_tt_prodmp['name'] = _v
+    kwargs_dict_tt_prodmp['controller_kwargs'].update({
+        'p_gains': 0.5 * np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0]),
+        'd_gains': 0.5 * np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1]),
+    })
+    kwargs_dict_tt_prodmp['trajectory_generator_kwargs'].update({
+        'auto_scale_basis': False,
+        'goal_offset': 1.0,
+    })
+    # learn_tau / learn_delay are switched on together with their bounds.
+    kwargs_dict_tt_prodmp['phase_generator_kwargs'].update({
+        'tau_bound': [0.8, 1.5],
+        'delay_bound': [0.05, 0.15],
+        'learn_tau': True,
+        'learn_delay': True,
+        'alpha_phase': 3,
+    })
+    kwargs_dict_tt_prodmp['basis_generator_kwargs'].update({
+        'num_basis': 2,
+        'alpha': 25.,
+        'basis_bandwidth_factor': 3,
+    })
+    # Replanning options are left disabled for this variant:
+    # kwargs_dict_tt_prodmp['black_box_kwargs']['max_planning_times'] = 3
+    # kwargs_dict_tt_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 50 == 0
+    register(
+        id=_env_id,
+        entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
+        kwargs=kwargs_dict_tt_prodmp,
+    )
+    ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
+
 for _v in _versions:
     _name = _v.split("-")
     _env_id = f'{_name[0]}ReplanProDMP-{_name[1]}'