diff --git a/fancy_gym/envs/mujoco/table_tennis/mp_wrapper.py b/fancy_gym/envs/mujoco/table_tennis/mp_wrapper.py index bd7cbf7..e33ed6c 100644 --- a/fancy_gym/envs/mujoco/table_tennis/mp_wrapper.py +++ b/fancy_gym/envs/mujoco/table_tennis/mp_wrapper.py @@ -37,7 +37,7 @@ class TT_MPWrapper(RawInterfaceWrapper): def invalid_traj_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray, return_contextual_obs: bool) -> Tuple[np.ndarray, float, bool, dict]: - return self.get_invalid_traj_step_return(action, pos_traj, vel_traj, return_contextual_obs) + return self.get_invalid_traj_step_return(action, pos_traj, return_contextual_obs) class TTVelObs_MPWrapper(TT_MPWrapper): diff --git a/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py b/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py index 2599006..734588a 100644 --- a/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py +++ b/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py @@ -234,7 +234,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel) return init_ball_state - def _get_traj_invalid_reward(self, action, pos_traj, vel_traj): + def _get_traj_invalid_reward(self, action, pos_traj): tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]])) delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]])) violate_high_bound_error = np.mean(np.maximum(pos_traj - jnt_pos_high, 0)) @@ -243,9 +243,9 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): violate_high_bound_error + violate_low_bound_error return -invalid_penalty - def get_invalid_traj_step_return(self, action, pos_traj, vel_traj, contextual_obs): + def get_invalid_traj_step_return(self, action, pos_traj, contextual_obs): obs = self._get_obs() if contextual_obs else np.concatenate([self._get_obs(), np.array([0])]) # 0 for invalid traj - penalty = self._get_traj_invalid_reward(action, pos_traj, vel_traj) + penalty = self._get_traj_invalid_reward(action, pos_traj) return obs, penalty, True, { "hit_ball": [False], "ball_return_success": [False],