learn only basis weight
This commit is contained in:
parent
e49d1563fe
commit
288ebedd50
@ -150,7 +150,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
""" This function generates a trajectory based on a MP and then does the usual loop over reset and step"""
|
""" This function generates a trajectory based on a MP and then does the usual loop over reset and step"""
|
||||||
|
|
||||||
## tricky part, only use weights basis
|
## tricky part, only use weights basis
|
||||||
weights_basis = action.reshape(-1, 7)
|
basis_weights = action.reshape(7, -1)
|
||||||
|
goal_weights = np.zeros((7, 1))
|
||||||
|
action = np.concatenate((basis_weights, goal_weights), axis=1).flatten()
|
||||||
|
|
||||||
# TODO remove this part, right now only needed for beer pong
|
# TODO remove this part, right now only needed for beer pong
|
||||||
mp_params, env_spec_params = self.env.episode_callback(action, self.traj_gen)
|
mp_params, env_spec_params = self.env.episode_callback(action, self.traj_gen)
|
||||||
|
@ -496,11 +496,11 @@ for _v in _versions:
|
|||||||
kwargs_dict_box_pushing_prodmp['name'] = _v
|
kwargs_dict_box_pushing_prodmp['name'] = _v
|
||||||
kwargs_dict_box_pushing_prodmp['controller_kwargs']['p_gains'] = 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.])
|
kwargs_dict_box_pushing_prodmp['controller_kwargs']['p_gains'] = 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.])
|
||||||
kwargs_dict_box_pushing_prodmp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.])
|
kwargs_dict_box_pushing_prodmp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.])
|
||||||
# kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = np.array([3.4944e+01, 4.3734e+01, 9.6711e+01, 2.4429e+02, 5.8272e+02])
|
kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = np.array([3.4944e+01, 4.3734e+01, 9.6711e+01, 2.4429e+02, 5.8272e+02])
|
||||||
# kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 3.1264e-01
|
kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 3.1264e-01
|
||||||
kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = 0.3 * np.array([100., 166., 500., 1000.])
|
# kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = 0.3 * np.array([100., 166., 500., 1000.])
|
||||||
kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 0.3 * 1.
|
# kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 0.3 * 1.
|
||||||
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['num_basis'] = 4
|
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['num_basis'] = 5
|
||||||
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['alpha'] = 10.
|
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['alpha'] = 10.
|
||||||
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_replan_times'] = 4
|
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_replan_times'] = 4
|
||||||
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 25 == 0
|
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 25 == 0
|
||||||
|
Loading…
Reference in New Issue
Block a user