From 288ebedd50b076f4370bb09053dda54312d65d5f Mon Sep 17 00:00:00 2001 From: Hongyi Zhou Date: Wed, 26 Oct 2022 22:54:35 +0200 Subject: [PATCH] learn only basis weight --- fancy_gym/black_box/black_box_wrapper.py | 4 +++- fancy_gym/envs/__init__.py | 10 +++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/fancy_gym/black_box/black_box_wrapper.py b/fancy_gym/black_box/black_box_wrapper.py index 908a664..6ba6cd9 100644 --- a/fancy_gym/black_box/black_box_wrapper.py +++ b/fancy_gym/black_box/black_box_wrapper.py @@ -150,7 +150,9 @@ class BlackBoxWrapper(gym.ObservationWrapper): """ This function generates a trajectory based on a MP and then does the usual loop over reset and step""" ## tricky part, only use weights basis - weights_basis = action.reshape(-1, 7) + basis_weights = action.reshape(7, -1) + goal_weights = np.zeros((7, 1)) + action = np.concatenate((basis_weights, goal_weights), axis=1).flatten() # TODO remove this part, right now only needed for beer pong mp_params, env_spec_params = self.env.episode_callback(action, self.traj_gen) diff --git a/fancy_gym/envs/__init__.py b/fancy_gym/envs/__init__.py index e2bea2e..1ad261f 100644 --- a/fancy_gym/envs/__init__.py +++ b/fancy_gym/envs/__init__.py @@ -496,11 +496,11 @@ for _v in _versions: kwargs_dict_box_pushing_prodmp['name'] = _v kwargs_dict_box_pushing_prodmp['controller_kwargs']['p_gains'] = 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.]) kwargs_dict_box_pushing_prodmp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.]) - # kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = np.array([3.4944e+01, 4.3734e+01, 9.6711e+01, 2.4429e+02, 5.8272e+02]) - # kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 3.1264e-01 - kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = 0.3 * np.array([100., 166., 500., 1000.]) - kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 0.3 * 1. - kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['num_basis'] = 4 + kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = np.array([3.4944e+01, 4.3734e+01, 9.6711e+01, 2.4429e+02, 5.8272e+02]) + kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 3.1264e-01 + # kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = 0.3 * np.array([100., 166., 500., 1000.]) + # kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 0.3 * 1. + kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['num_basis'] = 5 kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['alpha'] = 10. kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_replan_times'] = 4 kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 25 == 0