Fix: MujocoEnv no longer supports manual assignment of mujoco_bindings
This commit is contained in:
parent
dbd7c37da5
commit
110a8a9c0c
@ -22,6 +22,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
"""
|
"""
|
||||||
7 DoF table tennis environment
|
7 DoF table tennis environment
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4,
|
def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4,
|
||||||
goal_switching_step: int = None,
|
goal_switching_step: int = None,
|
||||||
enable_artificial_wind: bool = False):
|
enable_artificial_wind: bool = False):
|
||||||
@ -52,8 +53,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
MujocoEnv.__init__(self,
|
MujocoEnv.__init__(self,
|
||||||
model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
|
model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
|
||||||
frame_skip=frame_skip,
|
frame_skip=frame_skip,)
|
||||||
mujoco_bindings="mujoco")
|
|
||||||
|
|
||||||
if ctxt_dim == 2:
|
if ctxt_dim == 2:
|
||||||
self.context_bounds = CONTEXT_BOUNDS_2DIMS
|
self.context_bounds = CONTEXT_BOUNDS_2DIMS
|
||||||
@ -83,11 +83,11 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
unstable_simulation = False
|
unstable_simulation = False
|
||||||
|
|
||||||
if self._steps == self._goal_switching_step and self.np_random.uniform() < 0.5:
|
if self._steps == self._goal_switching_step and self.np_random.uniform() < 0.5:
|
||||||
new_goal_pos = self._generate_goal_pos(random=True)
|
new_goal_pos = self._generate_goal_pos(random=True)
|
||||||
new_goal_pos[1] = -new_goal_pos[1]
|
new_goal_pos[1] = -new_goal_pos[1]
|
||||||
self._goal_pos = new_goal_pos
|
self._goal_pos = new_goal_pos
|
||||||
self.model.body_pos[5] = np.concatenate([self._goal_pos, [0.77]])
|
self.model.body_pos[5] = np.concatenate([self._goal_pos, [0.77]])
|
||||||
mujoco.mj_forward(self.model, self.data)
|
mujoco.mj_forward(self.model, self.data)
|
||||||
|
|
||||||
for _ in range(self.frame_skip):
|
for _ in range(self.frame_skip):
|
||||||
if self._enable_artificial_wind:
|
if self._enable_artificial_wind:
|
||||||
@ -102,7 +102,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
if not self._hit_ball:
|
if not self._hit_ball:
|
||||||
self._hit_ball = self._contact_checker(self._ball_contact_id, self._bat_front_id) or \
|
self._hit_ball = self._contact_checker(self._ball_contact_id, self._bat_front_id) or \
|
||||||
self._contact_checker(self._ball_contact_id, self._bat_back_id)
|
self._contact_checker(self._ball_contact_id, self._bat_back_id)
|
||||||
if not self._hit_ball:
|
if not self._hit_ball:
|
||||||
ball_land_on_floor_no_hit = self._contact_checker(self._ball_contact_id, self._floor_contact_id)
|
ball_land_on_floor_no_hit = self._contact_checker(self._ball_contact_id, self._floor_contact_id)
|
||||||
if ball_land_on_floor_no_hit:
|
if ball_land_on_floor_no_hit:
|
||||||
@ -130,7 +130,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
reward = -25 if unstable_simulation else self._get_reward(self._terminated)
|
reward = -25 if unstable_simulation else self._get_reward(self._terminated)
|
||||||
|
|
||||||
land_dist_err = np.linalg.norm(self._ball_landing_pos[:-1] - self._goal_pos) \
|
land_dist_err = np.linalg.norm(self._ball_landing_pos[:-1] - self._goal_pos) \
|
||||||
if self._ball_landing_pos is not None else 10.
|
if self._ball_landing_pos is not None else 10.
|
||||||
|
|
||||||
return self._get_obs(), reward, self._terminated, {
|
return self._get_obs(), reward, self._terminated, {
|
||||||
"hit_ball": self._hit_ball,
|
"hit_ball": self._hit_ball,
|
||||||
@ -202,7 +202,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
if not self._hit_ball:
|
if not self._hit_ball:
|
||||||
return 0.2 * (1 - np.tanh(min_r_b_dist**2))
|
return 0.2 * (1 - np.tanh(min_r_b_dist**2))
|
||||||
if self._ball_landing_pos is None:
|
if self._ball_landing_pos is None:
|
||||||
min_b_des_b_dist = np.min(np.linalg.norm(np.array(self._ball_traj)[:,:2] - self._goal_pos[:2], axis=1))
|
min_b_des_b_dist = np.min(np.linalg.norm(np.array(self._ball_traj)[:, :2] - self._goal_pos[:2], axis=1))
|
||||||
return 2 * (1 - np.tanh(min_r_b_dist ** 2)) + (1 - np.tanh(min_b_des_b_dist**2))
|
return 2 * (1 - np.tanh(min_r_b_dist ** 2)) + (1 - np.tanh(min_b_des_b_dist**2))
|
||||||
min_b_des_b_land_dist = np.linalg.norm(self._goal_pos[:2] - self._ball_landing_pos[:2])
|
min_b_des_b_land_dist = np.linalg.norm(self._goal_pos[:2] - self._ball_landing_pos[:2])
|
||||||
over_net_bonus = int(self._ball_landing_pos[0] < 0)
|
over_net_bonus = int(self._ball_landing_pos[0] < 0)
|
||||||
@ -231,11 +231,11 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
violate_high_bound_error = np.mean(np.maximum(pos_traj - jnt_pos_high, 0))
|
violate_high_bound_error = np.mean(np.maximum(pos_traj - jnt_pos_high, 0))
|
||||||
violate_low_bound_error = np.mean(np.maximum(jnt_pos_low - pos_traj, 0))
|
violate_low_bound_error = np.mean(np.maximum(jnt_pos_low - pos_traj, 0))
|
||||||
invalid_penalty = tau_invalid_penalty + delay_invalid_penalty + \
|
invalid_penalty = tau_invalid_penalty + delay_invalid_penalty + \
|
||||||
violate_high_bound_error + violate_low_bound_error
|
violate_high_bound_error + violate_low_bound_error
|
||||||
return -invalid_penalty
|
return -invalid_penalty
|
||||||
|
|
||||||
def get_invalid_traj_step_return(self, action, pos_traj, contextual_obs):
|
def get_invalid_traj_step_return(self, action, pos_traj, contextual_obs):
|
||||||
obs = self._get_obs() if contextual_obs else np.concatenate([self._get_obs(), np.array([0])]) # 0 for invalid traj
|
obs = self._get_obs() if contextual_obs else np.concatenate([self._get_obs(), np.array([0])]) # 0 for invalid traj
|
||||||
penalty = self._get_traj_invalid_penalty(action, pos_traj)
|
penalty = self._get_traj_invalid_penalty(action, pos_traj)
|
||||||
return obs, penalty, True, {
|
return obs, penalty, True, {
|
||||||
"hit_ball": [False],
|
"hit_ball": [False],
|
||||||
@ -249,7 +249,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def check_traj_validity(action, pos_traj, vel_traj):
|
def check_traj_validity(action, pos_traj, vel_traj):
|
||||||
time_invalid = action[0] > tau_bound[1] or action[0] < tau_bound[0] \
|
time_invalid = action[0] > tau_bound[1] or action[0] < tau_bound[0] \
|
||||||
or action[1] > delay_bound[1] or action[1] < delay_bound[0]
|
or action[1] > delay_bound[1] or action[1] < delay_bound[0]
|
||||||
if time_invalid or np.any(pos_traj > jnt_pos_high) or np.any(pos_traj < jnt_pos_low):
|
if time_invalid or np.any(pos_traj > jnt_pos_high) or np.any(pos_traj < jnt_pos_low):
|
||||||
return False, pos_traj, vel_traj
|
return False, pos_traj, vel_traj
|
||||||
return True, pos_traj, vel_traj
|
return True, pos_traj, vel_traj
|
||||||
|
Loading…
Reference in New Issue
Block a user