This commit is contained in:
Hongyi Zhou 2022-11-02 23:00:20 +01:00
parent a6cca617e1
commit 7b2451d317
3 changed files with 59 additions and 17 deletions

View File

@ -1,8 +1,8 @@
<mujocoinclude> <mujocoinclude>
<body name="target_ball" pos="0. 0. 0.1"> <body name="target_ball" pos="0. 0. 0.1">
<joint axis="1 0 0" damping="0.0" name="tar:x" pos="0 0 0" stiffness="0" type="slide" frictionloss="0"/> <joint axis="1 0 0" damping="0.0" name="tar_x" pos="0 0 0" stiffness="0" type="slide" frictionloss="0"/>
<joint axis="0 1 0" damping="0.0" name="tar:y" pos="0 0 0" stiffness="0" type="slide" frictionloss="0"/> <joint axis="0 1 0" damping="0.0" name="tar_y" pos="0 0 0" stiffness="0" type="slide" frictionloss="0"/>
<joint axis="0 0 1" damping="0.0" name="tar:z" pos="0 0 0" stiffness="0" type="slide" frictionloss="0"/> <joint axis="0 0 1" damping="0.0" name="tar_z" pos="0 0 0" stiffness="0" type="slide" frictionloss="0"/>
<geom size="0.025 0.025 0.025" type="sphere" condim="4" name="target_ball_contact" rgba="1 1 0 1" mass="0.1" <geom size="0.025 0.025 0.025" type="sphere" condim="4" name="target_ball_contact" rgba="1 1 0 1" mass="0.1"
friction="0.1 0.1 0.1" solimp="0.9 0.95 0.001 0.5 2" solref="0.1 0.03" priority="1"/> friction="0.1 0.1 0.1" solimp="0.9 0.95 0.001 0.5 2" solref="0.1 0.03" priority="1"/>
<site name="target_ball" pos="0 0 0" size="0.02 0.02 0.02" rgba="1 0 0 1" type="sphere"/> <site name="target_ball" pos="0 0 0" size="0.02 0.02 0.02" rgba="1 0 0 1" type="sphere"/>

View File

@ -12,8 +12,8 @@
<geom conaffinity="1" contype="1" material="floor_plane" name="floor" pos="0 0 0" size="10 5 1" type="plane" /> <geom conaffinity="1" contype="1" material="floor_plane" name="floor" pos="0 0 0" size="10 5 1" type="plane" />
<include file="include_table.xml" /> <include file="include_table.xml" />
<include file="include_barrett_wam_7dof_right.xml" /> <include file="include_barrett_wam_7dof_right.xml" />
<!-- <include file="include_target_ball.xml" />--> <include file="include_target_ball.xml" />
<include file="include_free_ball.xml" /> <!-- <include file="include_free_ball.xml" />-->
</worldbody> </worldbody>
<include file="include_7_motor_actuator.xml" /> <include file="include_7_motor_actuator.xml" />
<!-- <include file="right_arm_actuator.xml"/>--> <!-- <include file="right_arm_actuator.xml"/>-->

View File

@ -18,25 +18,33 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
7 DoF table tennis environment 7 DoF table tennis environment
""" """
def __init__(self, frame_skip: int = 4): def __init__(self, ctxt_dim: int = 2, frame_skip: int = 4):
utils.EzPickle.__init__(**locals()) utils.EzPickle.__init__(**locals())
self._steps = 0 self._steps = 0
self.hit_ball = False self.hit_ball = False
self.ball_land_on_table = False self.ball_land_on_table = False
self._id_set = False self._id_set = False
self.ball_landing_pos = None
MujocoEnv.__init__(self, MujocoEnv.__init__(self,
model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"), model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
frame_skip=frame_skip, frame_skip=frame_skip,
mujoco_bindings="mujoco") mujoco_bindings="mujoco")
if ctxt_dim == 2:
self.context_bounds = CONTEXT_BOUNDS_2DIMS
elif ctxt_dim == 4:
self.context_bounds = CONTEXT_BOUNDS_4DIMS
else:
raise NotImplementedError
self.action_space = spaces.Box(low=-1, high=1, shape=(7,), dtype=np.float32) self.action_space = spaces.Box(low=-1, high=1, shape=(7,), dtype=np.float32)
def _set_ids(self): def _set_ids(self):
self._floor_id = self.model.geom("floor").bodyid[0] self._floor_contact_id = self.model.geom("floor").bodyid[0]
self._ball_id = self.model.geom("target_ball_contact").bodyid[0] self._ball_contact_id = self.model.geom("target_ball_contact").bodyid[0]
self._bat_front_id = self.model.geom("bat").bodyid[0] self._bat_front_id = self.model.geom("bat").bodyid[0]
self._bat_back_id = self.model.geom("bat_back").bodyid[0] self._bat_back_id = self.model.geom("bat_back").bodyid[0]
self._table_id = self.model.geom("table_tennis_table").bodyid[0] self._table_contact_id = self.model.geom("table_tennis_table").bodyid[0]
self._id_set = True self._id_set = True
def step(self, action): def step(self, action):
@ -45,11 +53,33 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
unstable_simulation = False unstable_simulation = False
try: done = False
self.do_simulation(action, self.frame_skip)
except Exception as e: for _ in range(self.frame_skip):
print("Simulation get unstable return with MujocoException: ", e) try:
unstable_simulation = True self.do_simulation(action, self.frame_skip)
except Exception as e:
print("Simulation get unstable return with MujocoException: ", e)
unstable_simulation = True
if not self.hit_ball:
self.hit_ball = self._contact_checker(self._ball_contact_id, self._bat_front_id) or \
self._contact_checker(self._ball_contact_id, self._bat_back_id)
if not self.hit_ball:
ball_land_on_floor_no_hit = self._contact_checker(self._ball_contact_id, self._floor_contact_id)
if ball_land_on_floor_no_hit:
self.ball_landing_pos = self.data.body("target_ball").xpos.copy()
done = True
if self.hit_ball and not self.ball_contact_after_hit:
if not self.ball_contact_after_hit:
if self._contact_checker(self._ball_contact_id, self._floor_contact_id): # first check contact with floor
self.ball_contact_after_hit = True
self.ball_landing_pos = self.sim.data.geom("target_ball_contact").xpos.copy()
elif self._contact_checker(self._ball_contact_id, self._table_contact_id): # second check contact with table
self.ball_contact_after_hit = True
self.ball_landing_pos = self.sim.data.geom("target_ball_contact").xpos.copy()
if self.ball_landing_pos[0] < 0.: # ball lands on the opponent side
self.ball_return_success = True
self._steps += 1 self._steps += 1
episode_end = True if self._steps >= MAX_EPISODE_STEPS_TABLE_TENNIS else False episode_end = True if self._steps >= MAX_EPISODE_STEPS_TABLE_TENNIS else False
@ -67,8 +97,19 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
def reset_model(self): def reset_model(self):
self._steps = 0 self._steps = 0
new_context = self._sample_context()
self.data.joint("tar_x").qpos = new_context[0]
self.data.joint("tar_y").qpos = new_context[1]
self.data.joint("tar_z").qvel = 2.
self.ball_landing_pos = None
self.hit_ball = False
return self._get_obs() return self._get_obs()
def _sample_context(self):
return self.np_random.uniform(low=self.context_bounds[0],
high=self.context_bounds[1])
def _get_obs(self): def _get_obs(self):
obs = np.concatenate([ obs = np.concatenate([
self.data.qpos.flat[:7], self.data.qpos.flat[:7],
@ -80,6 +121,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
if __name__ == "__main__": if __name__ == "__main__":
env = TableTennisEnv() env = TableTennisEnv()
env.reset() env.reset()
while True: for _ in range(1000):
env.render("human") for _ in range(200):
env.step(env.action_space.sample()) env.render("human")
env.step(env.action_space.sample())