delete unused argument
This commit is contained in:
parent
55df1e0ef6
commit
fe2d8fec91
@ -37,7 +37,7 @@ class TT_MPWrapper(RawInterfaceWrapper):
|
|||||||
|
|
||||||
def invalid_traj_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray,
|
def invalid_traj_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray,
|
||||||
return_contextual_obs: bool) -> Tuple[np.ndarray, float, bool, dict]:
|
return_contextual_obs: bool) -> Tuple[np.ndarray, float, bool, dict]:
|
||||||
return self.get_invalid_traj_step_return(action, pos_traj, vel_traj, return_contextual_obs)
|
return self.get_invalid_traj_step_return(action, pos_traj, return_contextual_obs)
|
||||||
|
|
||||||
class TTVelObs_MPWrapper(TT_MPWrapper):
|
class TTVelObs_MPWrapper(TT_MPWrapper):
|
||||||
|
|
||||||
|
@ -234,7 +234,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel)
|
init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel)
|
||||||
return init_ball_state
|
return init_ball_state
|
||||||
|
|
||||||
def _get_traj_invalid_reward(self, action, pos_traj, vel_traj):
|
def _get_traj_invalid_reward(self, action, pos_traj):
|
||||||
tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
|
tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
|
||||||
delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
|
delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
|
||||||
violate_high_bound_error = np.mean(np.maximum(pos_traj - jnt_pos_high, 0))
|
violate_high_bound_error = np.mean(np.maximum(pos_traj - jnt_pos_high, 0))
|
||||||
@ -243,9 +243,9 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
violate_high_bound_error + violate_low_bound_error
|
violate_high_bound_error + violate_low_bound_error
|
||||||
return -invalid_penalty
|
return -invalid_penalty
|
||||||
|
|
||||||
def get_invalid_traj_step_return(self, action, pos_traj, vel_traj, contextual_obs):
|
def get_invalid_traj_step_return(self, action, pos_traj, contextual_obs):
|
||||||
obs = self._get_obs() if contextual_obs else np.concatenate([self._get_obs(), np.array([0])]) # 0 for invalid traj
|
obs = self._get_obs() if contextual_obs else np.concatenate([self._get_obs(), np.array([0])]) # 0 for invalid traj
|
||||||
penalty = self._get_traj_invalid_reward(action, pos_traj, vel_traj)
|
penalty = self._get_traj_invalid_reward(action, pos_traj)
|
||||||
return obs, penalty, True, {
|
return obs, penalty, True, {
|
||||||
"hit_ball": [False],
|
"hit_ball": [False],
|
||||||
"ball_return_success": [False],
|
"ball_return_success": [False],
|
||||||
|
Loading…
Reference in New Issue
Block a user