delete unused argument
This commit is contained in:
parent
55df1e0ef6
commit
fe2d8fec91
@ -37,7 +37,7 @@ class TT_MPWrapper(RawInterfaceWrapper):
|
||||
|
||||
def invalid_traj_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray,
|
||||
return_contextual_obs: bool) -> Tuple[np.ndarray, float, bool, dict]:
|
||||
return self.get_invalid_traj_step_return(action, pos_traj, vel_traj, return_contextual_obs)
|
||||
return self.get_invalid_traj_step_return(action, pos_traj, return_contextual_obs)
|
||||
|
||||
class TTVelObs_MPWrapper(TT_MPWrapper):
|
||||
|
||||
|
@ -234,7 +234,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
||||
init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel)
|
||||
return init_ball_state
|
||||
|
||||
def _get_traj_invalid_reward(self, action, pos_traj, vel_traj):
|
||||
def _get_traj_invalid_reward(self, action, pos_traj):
|
||||
tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
|
||||
delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
|
||||
violate_high_bound_error = np.mean(np.maximum(pos_traj - jnt_pos_high, 0))
|
||||
@ -243,9 +243,9 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
||||
violate_high_bound_error + violate_low_bound_error
|
||||
return -invalid_penalty
|
||||
|
||||
def get_invalid_traj_step_return(self, action, pos_traj, vel_traj, contextual_obs):
|
||||
def get_invalid_traj_step_return(self, action, pos_traj, contextual_obs):
|
||||
obs = self._get_obs() if contextual_obs else np.concatenate([self._get_obs(), np.array([0])]) # 0 for invalid traj
|
||||
penalty = self._get_traj_invalid_reward(action, pos_traj, vel_traj)
|
||||
penalty = self._get_traj_invalid_reward(action, pos_traj)
|
||||
return obs, penalty, True, {
|
||||
"hit_ball": [False],
|
||||
"ball_return_success": [False],
|
||||
|
Loading…
Reference in New Issue
Block a user