minor updates

This commit is contained in:
Hongyi Zhou 2022-11-14 17:48:15 +01:00
parent be14b21fff
commit fc3051bf57
3 changed files with 7 additions and 10 deletions

View File

@ -86,7 +86,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):
return observation.astype(self.observation_space.dtype)
def get_trajectory(self, action: np.ndarray) -> Tuple:
# duration = self.duration - self.current_traj_steps * self.dt
duration = self.duration
if self.learn_sub_trajectories:
duration = None

View File

@ -498,15 +498,13 @@ for _v in _versions:
kwargs_dict_box_pushing_prodmp['name'] = _v
kwargs_dict_box_pushing_prodmp['controller_kwargs']['p_gains'] = 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.])
kwargs_dict_box_pushing_prodmp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.])
# kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = np.array([3.4944e+01, 4.3734e+01, 9.6711e+01, 2.4429e+02, 5.8272e+02])
# kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 3.1264e-01
kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = 0.3
kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 0.3
kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['auto_scale_basis'] = True
kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_offset'] = 1.0
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['num_basis'] = 0
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['alpha'] = 10.
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3 # 3.5, 4 to try
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3
kwargs_dict_box_pushing_prodmp['phase_generator_kwargs']['alpha_phase'] = 3
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_planning_times'] = 2
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 25 == 0

View File

@ -157,17 +157,17 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
if __name__ == '__main__':
render = True
# DMP
# example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)
example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)
# ProMP
# example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
# example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render)
example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render)
# ProDMP
example_mp("BoxPushingDenseProDMP-v0", seed=10, iterations=16, render=render)
example_mp("BoxPushingDenseProDMP-v0", seed=10, iterations=4, render=render)
# Altered basis functions
# obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)
obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)
# Custom MP
# example_fully_custom_mp(seed=10, iterations=1, render=render)
example_fully_custom_mp(seed=10, iterations=1, render=render)