delete unused argument

This commit is contained in:
Hongyi Zhou 2022-12-01 13:22:45 +01:00
parent 55df1e0ef6
commit fe2d8fec91
2 changed files with 4 additions and 4 deletions

View File

@ -37,7 +37,7 @@ class TT_MPWrapper(RawInterfaceWrapper):
def invalid_traj_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray, def invalid_traj_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray,
return_contextual_obs: bool) -> Tuple[np.ndarray, float, bool, dict]: return_contextual_obs: bool) -> Tuple[np.ndarray, float, bool, dict]:
return self.get_invalid_traj_step_return(action, pos_traj, vel_traj, return_contextual_obs) return self.get_invalid_traj_step_return(action, pos_traj, return_contextual_obs)
class TTVelObs_MPWrapper(TT_MPWrapper): class TTVelObs_MPWrapper(TT_MPWrapper):

View File

@ -234,7 +234,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel) init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel)
return init_ball_state return init_ball_state
def _get_traj_invalid_reward(self, action, pos_traj, vel_traj): def _get_traj_invalid_reward(self, action, pos_traj):
tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]])) tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]])) delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
violate_high_bound_error = np.mean(np.maximum(pos_traj - jnt_pos_high, 0)) violate_high_bound_error = np.mean(np.maximum(pos_traj - jnt_pos_high, 0))
@ -243,9 +243,9 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
violate_high_bound_error + violate_low_bound_error violate_high_bound_error + violate_low_bound_error
return -invalid_penalty return -invalid_penalty
def get_invalid_traj_step_return(self, action, pos_traj, vel_traj, contextual_obs): def get_invalid_traj_step_return(self, action, pos_traj, contextual_obs):
obs = self._get_obs() if contextual_obs else np.concatenate([self._get_obs(), np.array([0])]) # 0 for invalid traj obs = self._get_obs() if contextual_obs else np.concatenate([self._get_obs(), np.array([0])]) # 0 for invalid traj
penalty = self._get_traj_invalid_reward(action, pos_traj, vel_traj) penalty = self._get_traj_invalid_reward(action, pos_traj)
return obs, penalty, True, { return obs, penalty, True, {
"hit_ball": [False], "hit_ball": [False],
"ball_return_success": [False], "ball_return_success": [False],