Fix: MujocoEnv no longer supports manual assignment of mujoco_bindings
This commit is contained in:
parent
dbd7c37da5
commit
110a8a9c0c
@ -22,6 +22,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
||||
"""
|
||||
7 DoF table tennis environment
|
||||
"""
|
||||
|
||||
def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4,
|
||||
goal_switching_step: int = None,
|
||||
enable_artificial_wind: bool = False):
|
||||
@ -52,9 +53,8 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
||||
|
||||
MujocoEnv.__init__(self,
|
||||
model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
|
||||
frame_skip=frame_skip,
|
||||
mujoco_bindings="mujoco")
|
||||
|
||||
frame_skip=frame_skip,)
|
||||
|
||||
if ctxt_dim == 2:
|
||||
self.context_bounds = CONTEXT_BOUNDS_2DIMS
|
||||
elif ctxt_dim == 4:
|
||||
@ -83,11 +83,11 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
||||
unstable_simulation = False
|
||||
|
||||
if self._steps == self._goal_switching_step and self.np_random.uniform() < 0.5:
|
||||
new_goal_pos = self._generate_goal_pos(random=True)
|
||||
new_goal_pos[1] = -new_goal_pos[1]
|
||||
self._goal_pos = new_goal_pos
|
||||
self.model.body_pos[5] = np.concatenate([self._goal_pos, [0.77]])
|
||||
mujoco.mj_forward(self.model, self.data)
|
||||
new_goal_pos = self._generate_goal_pos(random=True)
|
||||
new_goal_pos[1] = -new_goal_pos[1]
|
||||
self._goal_pos = new_goal_pos
|
||||
self.model.body_pos[5] = np.concatenate([self._goal_pos, [0.77]])
|
||||
mujoco.mj_forward(self.model, self.data)
|
||||
|
||||
for _ in range(self.frame_skip):
|
||||
if self._enable_artificial_wind:
|
||||
@ -102,7 +102,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
||||
|
||||
if not self._hit_ball:
|
||||
self._hit_ball = self._contact_checker(self._ball_contact_id, self._bat_front_id) or \
|
||||
self._contact_checker(self._ball_contact_id, self._bat_back_id)
|
||||
self._contact_checker(self._ball_contact_id, self._bat_back_id)
|
||||
if not self._hit_ball:
|
||||
ball_land_on_floor_no_hit = self._contact_checker(self._ball_contact_id, self._floor_contact_id)
|
||||
if ball_land_on_floor_no_hit:
|
||||
@ -130,7 +130,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
||||
reward = -25 if unstable_simulation else self._get_reward(self._terminated)
|
||||
|
||||
land_dist_err = np.linalg.norm(self._ball_landing_pos[:-1] - self._goal_pos) \
|
||||
if self._ball_landing_pos is not None else 10.
|
||||
if self._ball_landing_pos is not None else 10.
|
||||
|
||||
return self._get_obs(), reward, self._terminated, {
|
||||
"hit_ball": self._hit_ball,
|
||||
@ -202,7 +202,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
||||
if not self._hit_ball:
|
||||
return 0.2 * (1 - np.tanh(min_r_b_dist**2))
|
||||
if self._ball_landing_pos is None:
|
||||
min_b_des_b_dist = np.min(np.linalg.norm(np.array(self._ball_traj)[:,:2] - self._goal_pos[:2], axis=1))
|
||||
min_b_des_b_dist = np.min(np.linalg.norm(np.array(self._ball_traj)[:, :2] - self._goal_pos[:2], axis=1))
|
||||
return 2 * (1 - np.tanh(min_r_b_dist ** 2)) + (1 - np.tanh(min_b_des_b_dist**2))
|
||||
min_b_des_b_land_dist = np.linalg.norm(self._goal_pos[:2] - self._ball_landing_pos[:2])
|
||||
over_net_bonus = int(self._ball_landing_pos[0] < 0)
|
||||
@ -231,11 +231,11 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
||||
violate_high_bound_error = np.mean(np.maximum(pos_traj - jnt_pos_high, 0))
|
||||
violate_low_bound_error = np.mean(np.maximum(jnt_pos_low - pos_traj, 0))
|
||||
invalid_penalty = tau_invalid_penalty + delay_invalid_penalty + \
|
||||
violate_high_bound_error + violate_low_bound_error
|
||||
violate_high_bound_error + violate_low_bound_error
|
||||
return -invalid_penalty
|
||||
|
||||
def get_invalid_traj_step_return(self, action, pos_traj, contextual_obs):
|
||||
obs = self._get_obs() if contextual_obs else np.concatenate([self._get_obs(), np.array([0])]) # 0 for invalid traj
|
||||
obs = self._get_obs() if contextual_obs else np.concatenate([self._get_obs(), np.array([0])]) # 0 for invalid traj
|
||||
penalty = self._get_traj_invalid_penalty(action, pos_traj)
|
||||
return obs, penalty, True, {
|
||||
"hit_ball": [False],
|
||||
@ -249,7 +249,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
||||
@staticmethod
|
||||
def check_traj_validity(action, pos_traj, vel_traj):
|
||||
time_invalid = action[0] > tau_bound[1] or action[0] < tau_bound[0] \
|
||||
or action[1] > delay_bound[1] or action[1] < delay_bound[0]
|
||||
or action[1] > delay_bound[1] or action[1] < delay_bound[0]
|
||||
if time_invalid or np.any(pos_traj > jnt_pos_high) or np.any(pos_traj < jnt_pos_low):
|
||||
return False, pos_traj, vel_traj
|
||||
return True, pos_traj, vel_traj
|
||||
|
Loading…
Reference in New Issue
Block a user