updates && disable auto_scale_basis for table tennis

This commit is contained in:
Hongyi Zhou 2022-11-23 17:02:04 +01:00
parent e3d36dead0
commit f47f00a292
2 changed files with 3 additions and 3 deletions

View File

@ -569,7 +569,7 @@ for _v in _versions:
kwargs_dict_tt_prodmp['controller_kwargs']['d_gains'] = 0.5 * np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1]) kwargs_dict_tt_prodmp['controller_kwargs']['d_gains'] = 0.5 * np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1])
kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['weights_scale'] = 1.0 kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['weights_scale'] = 1.0
kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['goal_scale'] = 1.0 kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['goal_scale'] = 1.0
kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['auto_scale_basis'] = True kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['auto_scale_basis'] = False
kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['goal_offset'] = 1.0 kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['goal_offset'] = 1.0
kwargs_dict_tt_prodmp['phase_generator_kwargs']['learn_tau'] = True kwargs_dict_tt_prodmp['phase_generator_kwargs']['learn_tau'] = True
kwargs_dict_tt_prodmp['phase_generator_kwargs']['learn_delay'] = True kwargs_dict_tt_prodmp['phase_generator_kwargs']['learn_delay'] = True

View File

@ -151,8 +151,8 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
def reset_model(self): def reset_model(self):
self._steps = 0 self._steps = 0
self._init_ball_state = self._generate_valid_init_ball(random_pos=True, random_vel=False) self._init_ball_state = self._generate_valid_init_ball(random_pos=False, random_vel=False)
self._goal_pos = self._generate_goal_pos(random=True) self._goal_pos = self._generate_goal_pos(random=False)
self.data.joint("tar_x").qpos = self._init_ball_state[0] self.data.joint("tar_x").qpos = self._init_ball_state[0]
self.data.joint("tar_y").qpos = self._init_ball_state[1] self.data.joint("tar_y").qpos = self._init_ball_state[1]
self.data.joint("tar_z").qpos = self._init_ball_state[2] self.data.joint("tar_z").qpos = self._init_ball_state[2]