updates
This commit is contained in:
parent
a6cca617e1
commit
7b2451d317
@ -1,8 +1,8 @@
|
|||||||
<mujocoinclude>
|
<mujocoinclude>
|
||||||
<body name="target_ball" pos="0. 0. 0.1">
|
<body name="target_ball" pos="0. 0. 0.1">
|
||||||
<joint axis="1 0 0" damping="0.0" name="tar:x" pos="0 0 0" stiffness="0" type="slide" frictionloss="0"/>
|
<joint axis="1 0 0" damping="0.0" name="tar_x" pos="0 0 0" stiffness="0" type="slide" frictionloss="0"/>
|
||||||
<joint axis="0 1 0" damping="0.0" name="tar:y" pos="0 0 0" stiffness="0" type="slide" frictionloss="0"/>
|
<joint axis="0 1 0" damping="0.0" name="tar_y" pos="0 0 0" stiffness="0" type="slide" frictionloss="0"/>
|
||||||
<joint axis="0 0 1" damping="0.0" name="tar:z" pos="0 0 0" stiffness="0" type="slide" frictionloss="0"/>
|
<joint axis="0 0 1" damping="0.0" name="tar_z" pos="0 0 0" stiffness="0" type="slide" frictionloss="0"/>
|
||||||
<geom size="0.025 0.025 0.025" type="sphere" condim="4" name="target_ball_contact" rgba="1 1 0 1" mass="0.1"
|
<geom size="0.025 0.025 0.025" type="sphere" condim="4" name="target_ball_contact" rgba="1 1 0 1" mass="0.1"
|
||||||
friction="0.1 0.1 0.1" solimp="0.9 0.95 0.001 0.5 2" solref="0.1 0.03" priority="1"/>
|
friction="0.1 0.1 0.1" solimp="0.9 0.95 0.001 0.5 2" solref="0.1 0.03" priority="1"/>
|
||||||
<site name="target_ball" pos="0 0 0" size="0.02 0.02 0.02" rgba="1 0 0 1" type="sphere"/>
|
<site name="target_ball" pos="0 0 0" size="0.02 0.02 0.02" rgba="1 0 0 1" type="sphere"/>
|
||||||
|
@ -12,8 +12,8 @@
|
|||||||
<geom conaffinity="1" contype="1" material="floor_plane" name="floor" pos="0 0 0" size="10 5 1" type="plane" />
|
<geom conaffinity="1" contype="1" material="floor_plane" name="floor" pos="0 0 0" size="10 5 1" type="plane" />
|
||||||
<include file="include_table.xml" />
|
<include file="include_table.xml" />
|
||||||
<include file="include_barrett_wam_7dof_right.xml" />
|
<include file="include_barrett_wam_7dof_right.xml" />
|
||||||
<!-- <include file="include_target_ball.xml" />-->
|
<include file="include_target_ball.xml" />
|
||||||
<include file="include_free_ball.xml" />
|
<!-- <include file="include_free_ball.xml" />-->
|
||||||
</worldbody>
|
</worldbody>
|
||||||
<include file="include_7_motor_actuator.xml" />
|
<include file="include_7_motor_actuator.xml" />
|
||||||
<!-- <include file="right_arm_actuator.xml"/>-->
|
<!-- <include file="right_arm_actuator.xml"/>-->
|
||||||
|
@ -18,25 +18,33 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
7 DoF table tennis environment
|
7 DoF table tennis environment
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, frame_skip: int = 4):
|
def __init__(self, ctxt_dim: int = 2, frame_skip: int = 4):
|
||||||
utils.EzPickle.__init__(**locals())
|
utils.EzPickle.__init__(**locals())
|
||||||
self._steps = 0
|
self._steps = 0
|
||||||
|
|
||||||
self.hit_ball = False
|
self.hit_ball = False
|
||||||
self.ball_land_on_table = False
|
self.ball_land_on_table = False
|
||||||
self._id_set = False
|
self._id_set = False
|
||||||
|
self.ball_landing_pos = None
|
||||||
MujocoEnv.__init__(self,
|
MujocoEnv.__init__(self,
|
||||||
model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
|
model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
|
||||||
frame_skip=frame_skip,
|
frame_skip=frame_skip,
|
||||||
mujoco_bindings="mujoco")
|
mujoco_bindings="mujoco")
|
||||||
|
if ctxt_dim == 2:
|
||||||
|
self.context_bounds = CONTEXT_BOUNDS_2DIMS
|
||||||
|
elif ctxt_dim == 4:
|
||||||
|
self.context_bounds = CONTEXT_BOUNDS_4DIMS
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
self.action_space = spaces.Box(low=-1, high=1, shape=(7,), dtype=np.float32)
|
self.action_space = spaces.Box(low=-1, high=1, shape=(7,), dtype=np.float32)
|
||||||
|
|
||||||
def _set_ids(self):
|
def _set_ids(self):
|
||||||
self._floor_id = self.model.geom("floor").bodyid[0]
|
self._floor_contact_id = self.model.geom("floor").bodyid[0]
|
||||||
self._ball_id = self.model.geom("target_ball_contact").bodyid[0]
|
self._ball_contact_id = self.model.geom("target_ball_contact").bodyid[0]
|
||||||
self._bat_front_id = self.model.geom("bat").bodyid[0]
|
self._bat_front_id = self.model.geom("bat").bodyid[0]
|
||||||
self._bat_back_id = self.model.geom("bat_back").bodyid[0]
|
self._bat_back_id = self.model.geom("bat_back").bodyid[0]
|
||||||
self._table_id = self.model.geom("table_tennis_table").bodyid[0]
|
self._table_contact_id = self.model.geom("table_tennis_table").bodyid[0]
|
||||||
self._id_set = True
|
self._id_set = True
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
@ -45,11 +53,33 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
unstable_simulation = False
|
unstable_simulation = False
|
||||||
|
|
||||||
try:
|
done = False
|
||||||
self.do_simulation(action, self.frame_skip)
|
|
||||||
except Exception as e:
|
for _ in range(self.frame_skip):
|
||||||
print("Simulation get unstable return with MujocoException: ", e)
|
try:
|
||||||
unstable_simulation = True
|
self.do_simulation(action, self.frame_skip)
|
||||||
|
except Exception as e:
|
||||||
|
print("Simulation get unstable return with MujocoException: ", e)
|
||||||
|
unstable_simulation = True
|
||||||
|
|
||||||
|
if not self.hit_ball:
|
||||||
|
self.hit_ball = self._contact_checker(self._ball_contact_id, self._bat_front_id) or \
|
||||||
|
self._contact_checker(self._ball_contact_id, self._bat_back_id)
|
||||||
|
if not self.hit_ball:
|
||||||
|
ball_land_on_floor_no_hit = self._contact_checker(self._ball_contact_id, self._floor_contact_id)
|
||||||
|
if ball_land_on_floor_no_hit:
|
||||||
|
self.ball_landing_pos = self.data.body("target_ball").xpos.copy()
|
||||||
|
done = True
|
||||||
|
if self.hit_ball and not self.ball_contact_after_hit:
|
||||||
|
if not self.ball_contact_after_hit:
|
||||||
|
if self._contact_checker(self._ball_contact_id, self._floor_contact_id): # first check contact with floor
|
||||||
|
self.ball_contact_after_hit = True
|
||||||
|
self.ball_landing_pos = self.sim.data.geom("target_ball_contact").xpos.copy()
|
||||||
|
elif self._contact_checker(self._ball_contact_id, self._table_contact_id): # second check contact with table
|
||||||
|
self.ball_contact_after_hit = True
|
||||||
|
self.ball_landing_pos = self.sim.data.geom("target_ball_contact").xpos.copy()
|
||||||
|
if self.ball_landing_pos[0] < 0.: # ball lands on the opponent side
|
||||||
|
self.ball_return_success = True
|
||||||
|
|
||||||
self._steps += 1
|
self._steps += 1
|
||||||
episode_end = True if self._steps >= MAX_EPISODE_STEPS_TABLE_TENNIS else False
|
episode_end = True if self._steps >= MAX_EPISODE_STEPS_TABLE_TENNIS else False
|
||||||
@ -67,8 +97,19 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
def reset_model(self):
|
def reset_model(self):
|
||||||
self._steps = 0
|
self._steps = 0
|
||||||
|
new_context = self._sample_context()
|
||||||
|
self.data.joint("tar_x").qpos = new_context[0]
|
||||||
|
self.data.joint("tar_y").qpos = new_context[1]
|
||||||
|
self.data.joint("tar_z").qvel = 2.
|
||||||
|
|
||||||
|
self.ball_landing_pos = None
|
||||||
|
self.hit_ball = False
|
||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
|
||||||
|
def _sample_context(self):
|
||||||
|
return self.np_random.uniform(low=self.context_bounds[0],
|
||||||
|
high=self.context_bounds[1])
|
||||||
|
|
||||||
def _get_obs(self):
|
def _get_obs(self):
|
||||||
obs = np.concatenate([
|
obs = np.concatenate([
|
||||||
self.data.qpos.flat[:7],
|
self.data.qpos.flat[:7],
|
||||||
@ -80,6 +121,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
env = TableTennisEnv()
|
env = TableTennisEnv()
|
||||||
env.reset()
|
env.reset()
|
||||||
while True:
|
for _ in range(1000):
|
||||||
env.render("human")
|
for _ in range(200):
|
||||||
env.step(env.action_space.sample())
|
env.render("human")
|
||||||
|
env.step(env.action_space.sample())
|
||||||
|
Loading…
Reference in New Issue
Block a user