diff --git a/fancy_gym/envs/__init__.py b/fancy_gym/envs/__init__.py index 35e1ccf..3b26345 100644 --- a/fancy_gym/envs/__init__.py +++ b/fancy_gym/envs/__init__.py @@ -77,6 +77,7 @@ DEFAULT_BB_DICT_ProDMP = { }, "phase_generator_kwargs": { 'phase_generator_type': 'exp', + 'tau': 1.5, }, "controller_kwargs": { 'controller_type': 'motor', @@ -529,7 +530,8 @@ for _v in _versions: kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 0.3 kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['auto_scale_basis'] = True kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_offset'] = 1.0 - kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['num_basis'] = 4 + kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['disable_goal'] = True + kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['num_basis'] = 5 kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3 kwargs_dict_box_pushing_prodmp['phase_generator_kwargs']['alpha_phase'] = 3 kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_planning_times'] = 4 @@ -560,7 +562,7 @@ for _v in _versions: kwargs_dict_tt_promp['phase_generator_kwargs']['tau_bound'] = [0.8, 1.5] kwargs_dict_tt_promp['phase_generator_kwargs']['delay_bound'] = [0.05, 0.15] kwargs_dict_tt_promp['basis_generator_kwargs']['num_basis'] = 3 - kwargs_dict_tt_promp['basis_generator_kwargs']['num_basis_zero_start'] = 2 + kwargs_dict_tt_promp['basis_generator_kwargs']['num_basis_zero_start'] = 1 kwargs_dict_tt_promp['basis_generator_kwargs']['num_basis_zero_goal'] = 1 kwargs_dict_tt_promp['black_box_kwargs']['verbose'] = 2 register( diff --git a/fancy_gym/examples/mp_params_tuning.py b/fancy_gym/examples/mp_params_tuning.py new file mode 100644 index 0000000..a0ac386 --- /dev/null +++ b/fancy_gym/examples/mp_params_tuning.py @@ -0,0 +1,10 @@ +import fancy_gym + +def compare_bases_shape(env1_id, env2_id): + env1 = fancy_gym.make(env1_id, seed=0) + env1.traj_gen.show_scaled_basis(plot=True) + env2 = fancy_gym.make(env2_id, seed=0) + env2.traj_gen.show_scaled_basis(plot=True) + return +if __name__ == '__main__': + compare_bases_shape("BoxPushingDenseReplanProDMP-v0", "BoxPushingDenseProMP-v0") \ No newline at end of file