learn only basis weight

This commit is contained in:
Hongyi Zhou 2022-10-26 22:54:35 +02:00
parent e49d1563fe
commit 288ebedd50
2 changed files with 8 additions and 6 deletions

View File

@ -150,7 +150,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
""" This function generates a trajectory based on a MP and then does the usual loop over reset and step"""
## tricky part, only use weights basis
weights_basis = action.reshape(-1, 7)
basis_weights = action.reshape(7, -1)
goal_weights = np.zeros((7, 1))
action = np.concatenate((basis_weights, goal_weights), axis=1).flatten()
# TODO remove this part, right now only needed for beer pong
mp_params, env_spec_params = self.env.episode_callback(action, self.traj_gen)

View File

@ -496,11 +496,11 @@ for _v in _versions:
kwargs_dict_box_pushing_prodmp['name'] = _v
kwargs_dict_box_pushing_prodmp['controller_kwargs']['p_gains'] = 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.])
kwargs_dict_box_pushing_prodmp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.])
# kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = np.array([3.4944e+01, 4.3734e+01, 9.6711e+01, 2.4429e+02, 5.8272e+02])
# kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 3.1264e-01
kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = 0.3 * np.array([100., 166., 500., 1000.])
kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 0.3 * 1.
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['num_basis'] = 4
kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = np.array([3.4944e+01, 4.3734e+01, 9.6711e+01, 2.4429e+02, 5.8272e+02])
kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 3.1264e-01
# kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = 0.3 * np.array([100., 166., 500., 1000.])
# kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 0.3 * 1.
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['num_basis'] = 5
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['alpha'] = 10.
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_replan_times'] = 4
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 25 == 0