From 110a8a9c0c37c11554202bae8070ad9909c4f4ae Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sat, 27 May 2023 12:55:46 +0200 Subject: [PATCH] Fix: MujocoEnv no longer supports manual assignment of mujoco_bindings --- .../mujoco/table_tennis/table_tennis_env.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py b/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py index 7fb5e9f..872aa75 100644 --- a/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py +++ b/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py @@ -22,6 +22,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): """ 7 DoF table tennis environment """ + def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4, goal_switching_step: int = None, enable_artificial_wind: bool = False): @@ -52,9 +53,8 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): MujocoEnv.__init__(self, model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"), - frame_skip=frame_skip, - mujoco_bindings="mujoco") - + frame_skip=frame_skip,) + if ctxt_dim == 2: self.context_bounds = CONTEXT_BOUNDS_2DIMS elif ctxt_dim == 4: @@ -83,11 +83,11 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): unstable_simulation = False if self._steps == self._goal_switching_step and self.np_random.uniform() < 0.5: - new_goal_pos = self._generate_goal_pos(random=True) - new_goal_pos[1] = -new_goal_pos[1] - self._goal_pos = new_goal_pos - self.model.body_pos[5] = np.concatenate([self._goal_pos, [0.77]]) - mujoco.mj_forward(self.model, self.data) + new_goal_pos = self._generate_goal_pos(random=True) + new_goal_pos[1] = -new_goal_pos[1] + self._goal_pos = new_goal_pos + self.model.body_pos[5] = np.concatenate([self._goal_pos, [0.77]]) + mujoco.mj_forward(self.model, self.data) for _ in range(self.frame_skip): if self._enable_artificial_wind: @@ -102,7 +102,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): if not self._hit_ball: self._hit_ball = self._contact_checker(self._ball_contact_id, self._bat_front_id) or \ - self._contact_checker(self._ball_contact_id, self._bat_back_id) + self._contact_checker(self._ball_contact_id, self._bat_back_id) if not self._hit_ball: ball_land_on_floor_no_hit = self._contact_checker(self._ball_contact_id, self._floor_contact_id) if ball_land_on_floor_no_hit: @@ -130,7 +130,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): reward = -25 if unstable_simulation else self._get_reward(self._terminated) land_dist_err = np.linalg.norm(self._ball_landing_pos[:-1] - self._goal_pos) \ - if self._ball_landing_pos is not None else 10. + if self._ball_landing_pos is not None else 10. return self._get_obs(), reward, self._terminated, { "hit_ball": self._hit_ball, @@ -202,7 +202,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): if not self._hit_ball: return 0.2 * (1 - np.tanh(min_r_b_dist**2)) if self._ball_landing_pos is None: - min_b_des_b_dist = np.min(np.linalg.norm(np.array(self._ball_traj)[:,:2] - self._goal_pos[:2], axis=1)) + min_b_des_b_dist = np.min(np.linalg.norm(np.array(self._ball_traj)[:, :2] - self._goal_pos[:2], axis=1)) return 2 * (1 - np.tanh(min_r_b_dist ** 2)) + (1 - np.tanh(min_b_des_b_dist**2)) min_b_des_b_land_dist = np.linalg.norm(self._goal_pos[:2] - self._ball_landing_pos[:2]) over_net_bonus = int(self._ball_landing_pos[0] < 0) @@ -231,11 +231,11 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): violate_high_bound_error = np.mean(np.maximum(pos_traj - jnt_pos_high, 0)) violate_low_bound_error = np.mean(np.maximum(jnt_pos_low - pos_traj, 0)) invalid_penalty = tau_invalid_penalty + delay_invalid_penalty + \ - violate_high_bound_error + violate_low_bound_error + violate_high_bound_error + violate_low_bound_error return -invalid_penalty def get_invalid_traj_step_return(self, action, pos_traj, contextual_obs): - obs = self._get_obs() if contextual_obs else np.concatenate([self._get_obs(), np.array([0])]) # 0 for invalid traj + obs = self._get_obs() if contextual_obs else np.concatenate([self._get_obs(), np.array([0])]) # 0 for invalid traj penalty = self._get_traj_invalid_penalty(action, pos_traj) return obs, penalty, True, { "hit_ball": [False], @@ -249,7 +249,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): @staticmethod def check_traj_validity(action, pos_traj, vel_traj): time_invalid = action[0] > tau_bound[1] or action[0] < tau_bound[0] \ - or action[1] > delay_bound[1] or action[1] < delay_bound[0] + or action[1] > delay_bound[1] or action[1] < delay_bound[0] if time_invalid or np.any(pos_traj > jnt_pos_high) or np.any(pos_traj < jnt_pos_low): return False, pos_traj, vel_traj return True, pos_traj, vel_traj