add replan name tag to replan envs & delete redundant settings
This commit is contained in:
parent
5744d339ac
commit
5750f6eb3d
@ -571,7 +571,7 @@ for _v in _versions:
|
||||
|
||||
for _v in _versions:
|
||||
_name = _v.split("-")
|
||||
_env_id = f'{_name[0]}ProDMP-{_name[1]}'
|
||||
_env_id = f'{_name[0]}ReplanProDMP-{_name[1]}'
|
||||
kwargs_dict_tt_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
|
||||
if _v == 'TableTennisWind-v0':
|
||||
kwargs_dict_tt_prodmp['wrappers'].append(mujoco.table_tennis.TTVelObs_MPWrapper)
|
||||
@ -580,8 +580,6 @@ for _v in _versions:
|
||||
kwargs_dict_tt_prodmp['name'] = _v
|
||||
kwargs_dict_tt_prodmp['controller_kwargs']['p_gains'] = 0.5 * np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0])
|
||||
kwargs_dict_tt_prodmp['controller_kwargs']['d_gains'] = 0.5 * np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1])
|
||||
kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['weights_scale'] = 1.0
|
||||
kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['goal_scale'] = 1.0
|
||||
kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['auto_scale_basis'] = False
|
||||
kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['goal_offset'] = 1.0
|
||||
kwargs_dict_tt_prodmp['phase_generator_kwargs']['tau_bound'] = [0.8, 1.5]
|
||||
@ -590,7 +588,7 @@ for _v in _versions:
|
||||
kwargs_dict_tt_prodmp['phase_generator_kwargs']['learn_delay'] = True
|
||||
kwargs_dict_tt_prodmp['basis_generator_kwargs']['num_basis'] = 2
|
||||
kwargs_dict_tt_prodmp['basis_generator_kwargs']['alpha'] = 25.
|
||||
kwargs_dict_tt_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3 # 3.5, 4 to try
|
||||
kwargs_dict_tt_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3
|
||||
kwargs_dict_tt_prodmp['phase_generator_kwargs']['alpha_phase'] = 3
|
||||
kwargs_dict_tt_prodmp['black_box_kwargs']['max_planning_times'] = 3
|
||||
kwargs_dict_tt_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 50 == 0
|
||||
|
@ -191,9 +191,6 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
||||
self.data.joint("tar_x").qpos.copy(),
|
||||
self.data.joint("tar_y").qpos.copy(),
|
||||
self.data.joint("tar_z").qpos.copy(),
|
||||
# self.data.joint("tar_x").qvel.copy(),
|
||||
# self.data.joint("tar_y").qvel.copy(),
|
||||
# self.data.joint("tar_z").qvel.copy(),
|
||||
self._goal_pos.copy(),
|
||||
])
|
||||
return obs
|
||||
@ -234,7 +231,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
||||
init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel)
|
||||
return init_ball_state
|
||||
|
||||
def _get_traj_invalid_reward(self, action, pos_traj):
|
||||
def _get_traj_invalid_penalty(self, action, pos_traj):
|
||||
tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
|
||||
delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
|
||||
violate_high_bound_error = np.mean(np.maximum(pos_traj - jnt_pos_high, 0))
|
||||
@ -245,7 +242,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
||||
|
||||
def get_invalid_traj_step_return(self, action, pos_traj, contextual_obs):
|
||||
obs = self._get_obs() if contextual_obs else np.concatenate([self._get_obs(), np.array([0])]) # 0 for invalid traj
|
||||
penalty = self._get_traj_invalid_reward(action, pos_traj)
|
||||
penalty = self._get_traj_invalid_penalty(action, pos_traj)
|
||||
return obs, penalty, True, {
|
||||
"hit_ball": [False],
|
||||
"ball_returned_success": [False],
|
||||
|
@ -155,25 +155,22 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
render = True
|
||||
render = False
|
||||
# DMP
|
||||
# example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)
|
||||
example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)
|
||||
|
||||
# ProMP
|
||||
# example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
|
||||
# example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render)
|
||||
example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
|
||||
example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render)
|
||||
example_mp("TableTennis4DProMP-v0", seed=10, iterations=20, render=render)
|
||||
example_mp("TableTennisWindProMP-v0", seed=10, iterations=20, render=render)
|
||||
example_mp("TableTennisGoalSwitchingProMP-v0", seed=10, iterations=20, render=render)
|
||||
|
||||
# ProDMP with Replanning
|
||||
# example_mp("BoxPushingDenseReplanProDMP-v0", seed=10, iterations=4, render=render)
|
||||
# example_mp("TableTennis4DProDMP-v0", seed=10, iterations=100, render=render)
|
||||
# example_mp("TableTennisWindProDMP-v0", seed=10, iterations=100, render=render)
|
||||
# example_mp("TableTennisGoalSwitchingProDMP-v0", seed=10, iterations=100, render=render)
|
||||
example_mp("BoxPushingDenseReplanProDMP-v0", seed=10, iterations=4, render=render)
|
||||
example_mp("TableTennis4DReplanProDMP-v0", seed=10, iterations=20, render=render)
|
||||
example_mp("TableTennisWindReplanProDMP-v0", seed=10, iterations=20, render=render)
|
||||
|
||||
# Altered basis functions
|
||||
# obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)
|
||||
obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)
|
||||
|
||||
# Custom MP
|
||||
# example_fully_custom_mp(seed=10, iterations=1, render=render)
|
||||
example_fully_custom_mp(seed=10, iterations=1, render=render)
|
||||
|
Loading…
Reference in New Issue
Block a user