Fix: MujocoEnv no longer supports manual assignment of mujoco_bindings

This commit is contained in:
Dominik Moritz Roth 2023-05-27 12:55:46 +02:00
parent dbd7c37da5
commit 110a8a9c0c

View File

@ -22,6 +22,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
""" """
7 DoF table tennis environment 7 DoF table tennis environment
""" """
def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4, def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4,
goal_switching_step: int = None, goal_switching_step: int = None,
enable_artificial_wind: bool = False): enable_artificial_wind: bool = False):
@ -52,9 +53,8 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
MujocoEnv.__init__(self, MujocoEnv.__init__(self,
model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"), model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
frame_skip=frame_skip, frame_skip=frame_skip,)
mujoco_bindings="mujoco")
if ctxt_dim == 2: if ctxt_dim == 2:
self.context_bounds = CONTEXT_BOUNDS_2DIMS self.context_bounds = CONTEXT_BOUNDS_2DIMS
elif ctxt_dim == 4: elif ctxt_dim == 4:
@ -83,11 +83,11 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
unstable_simulation = False unstable_simulation = False
if self._steps == self._goal_switching_step and self.np_random.uniform() < 0.5: if self._steps == self._goal_switching_step and self.np_random.uniform() < 0.5:
new_goal_pos = self._generate_goal_pos(random=True) new_goal_pos = self._generate_goal_pos(random=True)
new_goal_pos[1] = -new_goal_pos[1] new_goal_pos[1] = -new_goal_pos[1]
self._goal_pos = new_goal_pos self._goal_pos = new_goal_pos
self.model.body_pos[5] = np.concatenate([self._goal_pos, [0.77]]) self.model.body_pos[5] = np.concatenate([self._goal_pos, [0.77]])
mujoco.mj_forward(self.model, self.data) mujoco.mj_forward(self.model, self.data)
for _ in range(self.frame_skip): for _ in range(self.frame_skip):
if self._enable_artificial_wind: if self._enable_artificial_wind:
@ -102,7 +102,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
if not self._hit_ball: if not self._hit_ball:
self._hit_ball = self._contact_checker(self._ball_contact_id, self._bat_front_id) or \ self._hit_ball = self._contact_checker(self._ball_contact_id, self._bat_front_id) or \
self._contact_checker(self._ball_contact_id, self._bat_back_id) self._contact_checker(self._ball_contact_id, self._bat_back_id)
if not self._hit_ball: if not self._hit_ball:
ball_land_on_floor_no_hit = self._contact_checker(self._ball_contact_id, self._floor_contact_id) ball_land_on_floor_no_hit = self._contact_checker(self._ball_contact_id, self._floor_contact_id)
if ball_land_on_floor_no_hit: if ball_land_on_floor_no_hit:
@ -130,7 +130,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
reward = -25 if unstable_simulation else self._get_reward(self._terminated) reward = -25 if unstable_simulation else self._get_reward(self._terminated)
land_dist_err = np.linalg.norm(self._ball_landing_pos[:-1] - self._goal_pos) \ land_dist_err = np.linalg.norm(self._ball_landing_pos[:-1] - self._goal_pos) \
if self._ball_landing_pos is not None else 10. if self._ball_landing_pos is not None else 10.
return self._get_obs(), reward, self._terminated, { return self._get_obs(), reward, self._terminated, {
"hit_ball": self._hit_ball, "hit_ball": self._hit_ball,
@ -202,7 +202,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
if not self._hit_ball: if not self._hit_ball:
return 0.2 * (1 - np.tanh(min_r_b_dist**2)) return 0.2 * (1 - np.tanh(min_r_b_dist**2))
if self._ball_landing_pos is None: if self._ball_landing_pos is None:
min_b_des_b_dist = np.min(np.linalg.norm(np.array(self._ball_traj)[:,:2] - self._goal_pos[:2], axis=1)) min_b_des_b_dist = np.min(np.linalg.norm(np.array(self._ball_traj)[:, :2] - self._goal_pos[:2], axis=1))
return 2 * (1 - np.tanh(min_r_b_dist ** 2)) + (1 - np.tanh(min_b_des_b_dist**2)) return 2 * (1 - np.tanh(min_r_b_dist ** 2)) + (1 - np.tanh(min_b_des_b_dist**2))
min_b_des_b_land_dist = np.linalg.norm(self._goal_pos[:2] - self._ball_landing_pos[:2]) min_b_des_b_land_dist = np.linalg.norm(self._goal_pos[:2] - self._ball_landing_pos[:2])
over_net_bonus = int(self._ball_landing_pos[0] < 0) over_net_bonus = int(self._ball_landing_pos[0] < 0)
@ -231,11 +231,11 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
violate_high_bound_error = np.mean(np.maximum(pos_traj - jnt_pos_high, 0)) violate_high_bound_error = np.mean(np.maximum(pos_traj - jnt_pos_high, 0))
violate_low_bound_error = np.mean(np.maximum(jnt_pos_low - pos_traj, 0)) violate_low_bound_error = np.mean(np.maximum(jnt_pos_low - pos_traj, 0))
invalid_penalty = tau_invalid_penalty + delay_invalid_penalty + \ invalid_penalty = tau_invalid_penalty + delay_invalid_penalty + \
violate_high_bound_error + violate_low_bound_error violate_high_bound_error + violate_low_bound_error
return -invalid_penalty return -invalid_penalty
def get_invalid_traj_step_return(self, action, pos_traj, contextual_obs): def get_invalid_traj_step_return(self, action, pos_traj, contextual_obs):
obs = self._get_obs() if contextual_obs else np.concatenate([self._get_obs(), np.array([0])]) # 0 for invalid traj obs = self._get_obs() if contextual_obs else np.concatenate([self._get_obs(), np.array([0])]) # 0 for invalid traj
penalty = self._get_traj_invalid_penalty(action, pos_traj) penalty = self._get_traj_invalid_penalty(action, pos_traj)
return obs, penalty, True, { return obs, penalty, True, {
"hit_ball": [False], "hit_ball": [False],
@ -249,7 +249,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
@staticmethod @staticmethod
def check_traj_validity(action, pos_traj, vel_traj): def check_traj_validity(action, pos_traj, vel_traj):
time_invalid = action[0] > tau_bound[1] or action[0] < tau_bound[0] \ time_invalid = action[0] > tau_bound[1] or action[0] < tau_bound[0] \
or action[1] > delay_bound[1] or action[1] < delay_bound[0] or action[1] > delay_bound[1] or action[1] < delay_bound[0]
if time_invalid or np.any(pos_traj > jnt_pos_high) or np.any(pos_traj < jnt_pos_low): if time_invalid or np.any(pos_traj > jnt_pos_high) or np.any(pos_traj < jnt_pos_low):
return False, pos_traj, vel_traj return False, pos_traj, vel_traj
return True, pos_traj, vel_traj return True, pos_traj, vel_traj