add replan name tag to replan envs & delete redundant settings

This commit is contained in:
Hongyi Zhou 2022-12-01 14:23:57 +01:00
parent 5744d339ac
commit 5750f6eb3d
3 changed files with 13 additions and 21 deletions

View File

@ -571,7 +571,7 @@ for _v in _versions:
for _v in _versions:
_name = _v.split("-")
_env_id = f'{_name[0]}ProDMP-{_name[1]}'
_env_id = f'{_name[0]}ReplanProDMP-{_name[1]}'
kwargs_dict_tt_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
if _v == 'TableTennisWind-v0':
kwargs_dict_tt_prodmp['wrappers'].append(mujoco.table_tennis.TTVelObs_MPWrapper)
@ -580,8 +580,6 @@ for _v in _versions:
kwargs_dict_tt_prodmp['name'] = _v
kwargs_dict_tt_prodmp['controller_kwargs']['p_gains'] = 0.5 * np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0])
kwargs_dict_tt_prodmp['controller_kwargs']['d_gains'] = 0.5 * np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1])
kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['weights_scale'] = 1.0
kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['goal_scale'] = 1.0
kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['auto_scale_basis'] = False
kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['goal_offset'] = 1.0
kwargs_dict_tt_prodmp['phase_generator_kwargs']['tau_bound'] = [0.8, 1.5]
@ -590,7 +588,7 @@ for _v in _versions:
kwargs_dict_tt_prodmp['phase_generator_kwargs']['learn_delay'] = True
kwargs_dict_tt_prodmp['basis_generator_kwargs']['num_basis'] = 2
kwargs_dict_tt_prodmp['basis_generator_kwargs']['alpha'] = 25.
kwargs_dict_tt_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3 # 3.5, 4 to try
kwargs_dict_tt_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3
kwargs_dict_tt_prodmp['phase_generator_kwargs']['alpha_phase'] = 3
kwargs_dict_tt_prodmp['black_box_kwargs']['max_planning_times'] = 3
kwargs_dict_tt_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 50 == 0

View File

@ -191,9 +191,6 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
self.data.joint("tar_x").qpos.copy(),
self.data.joint("tar_y").qpos.copy(),
self.data.joint("tar_z").qpos.copy(),
# self.data.joint("tar_x").qvel.copy(),
# self.data.joint("tar_y").qvel.copy(),
# self.data.joint("tar_z").qvel.copy(),
self._goal_pos.copy(),
])
return obs
@ -234,7 +231,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel)
return init_ball_state
def _get_traj_invalid_reward(self, action, pos_traj):
def _get_traj_invalid_penalty(self, action, pos_traj):
tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
violate_high_bound_error = np.mean(np.maximum(pos_traj - jnt_pos_high, 0))
@ -245,7 +242,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
def get_invalid_traj_step_return(self, action, pos_traj, contextual_obs):
obs = self._get_obs() if contextual_obs else np.concatenate([self._get_obs(), np.array([0])]) # 0 for invalid traj
penalty = self._get_traj_invalid_reward(action, pos_traj)
penalty = self._get_traj_invalid_penalty(action, pos_traj)
return obs, penalty, True, {
"hit_ball": [False],
"ball_returned_success": [False],

View File

@ -155,25 +155,22 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
if __name__ == '__main__':
render = True
render = False
# DMP
# example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)
example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)
# ProMP
# example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
# example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render)
example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render)
example_mp("TableTennis4DProMP-v0", seed=10, iterations=20, render=render)
example_mp("TableTennisWindProMP-v0", seed=10, iterations=20, render=render)
example_mp("TableTennisGoalSwitchingProMP-v0", seed=10, iterations=20, render=render)
# ProDMP with Replanning
# example_mp("BoxPushingDenseReplanProDMP-v0", seed=10, iterations=4, render=render)
# example_mp("TableTennis4DProDMP-v0", seed=10, iterations=100, render=render)
# example_mp("TableTennisWindProDMP-v0", seed=10, iterations=100, render=render)
# example_mp("TableTennisGoalSwitchingProDMP-v0", seed=10, iterations=100, render=render)
example_mp("BoxPushingDenseReplanProDMP-v0", seed=10, iterations=4, render=render)
example_mp("TableTennis4DReplanProDMP-v0", seed=10, iterations=20, render=render)
example_mp("TableTennisWindReplanProDMP-v0", seed=10, iterations=20, render=render)
# Altered basis functions
# obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)
obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)
# Custom MP
# example_fully_custom_mp(seed=10, iterations=1, render=render)
example_fully_custom_mp(seed=10, iterations=1, render=render)