updates
This commit is contained in:
parent
a6cca617e1
commit
7b2451d317
@ -1,8 +1,8 @@
|
||||
<mujocoinclude>
|
||||
<body name="target_ball" pos="0. 0. 0.1">
|
||||
<joint axis="1 0 0" damping="0.0" name="tar:x" pos="0 0 0" stiffness="0" type="slide" frictionloss="0"/>
|
||||
<joint axis="0 1 0" damping="0.0" name="tar:y" pos="0 0 0" stiffness="0" type="slide" frictionloss="0"/>
|
||||
<joint axis="0 0 1" damping="0.0" name="tar:z" pos="0 0 0" stiffness="0" type="slide" frictionloss="0"/>
|
||||
<joint axis="1 0 0" damping="0.0" name="tar_x" pos="0 0 0" stiffness="0" type="slide" frictionloss="0"/>
|
||||
<joint axis="0 1 0" damping="0.0" name="tar_y" pos="0 0 0" stiffness="0" type="slide" frictionloss="0"/>
|
||||
<joint axis="0 0 1" damping="0.0" name="tar_z" pos="0 0 0" stiffness="0" type="slide" frictionloss="0"/>
|
||||
<geom size="0.025 0.025 0.025" type="sphere" condim="4" name="target_ball_contact" rgba="1 1 0 1" mass="0.1"
|
||||
friction="0.1 0.1 0.1" solimp="0.9 0.95 0.001 0.5 2" solref="0.1 0.03" priority="1"/>
|
||||
<site name="target_ball" pos="0 0 0" size="0.02 0.02 0.02" rgba="1 0 0 1" type="sphere"/>
|
||||
|
@ -12,8 +12,8 @@
|
||||
<geom conaffinity="1" contype="1" material="floor_plane" name="floor" pos="0 0 0" size="10 5 1" type="plane" />
|
||||
<include file="include_table.xml" />
|
||||
<include file="include_barrett_wam_7dof_right.xml" />
|
||||
<!-- <include file="include_target_ball.xml" />-->
|
||||
<include file="include_free_ball.xml" />
|
||||
<include file="include_target_ball.xml" />
|
||||
<!-- <include file="include_free_ball.xml" />-->
|
||||
</worldbody>
|
||||
<include file="include_7_motor_actuator.xml" />
|
||||
<!-- <include file="right_arm_actuator.xml"/>-->
|
||||
|
@ -18,25 +18,33 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
||||
7 DoF table tennis environment
|
||||
"""
|
||||
|
||||
def __init__(self, frame_skip: int = 4):
|
||||
def __init__(self, ctxt_dim: int = 2, frame_skip: int = 4):
|
||||
utils.EzPickle.__init__(**locals())
|
||||
self._steps = 0
|
||||
|
||||
self.hit_ball = False
|
||||
self.ball_land_on_table = False
|
||||
self._id_set = False
|
||||
self.ball_landing_pos = None
|
||||
MujocoEnv.__init__(self,
|
||||
model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
|
||||
frame_skip=frame_skip,
|
||||
mujoco_bindings="mujoco")
|
||||
if ctxt_dim == 2:
|
||||
self.context_bounds = CONTEXT_BOUNDS_2DIMS
|
||||
elif ctxt_dim == 4:
|
||||
self.context_bounds = CONTEXT_BOUNDS_4DIMS
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
self.action_space = spaces.Box(low=-1, high=1, shape=(7,), dtype=np.float32)
|
||||
|
||||
def _set_ids(self):
|
||||
self._floor_id = self.model.geom("floor").bodyid[0]
|
||||
self._ball_id = self.model.geom("target_ball_contact").bodyid[0]
|
||||
self._floor_contact_id = self.model.geom("floor").bodyid[0]
|
||||
self._ball_contact_id = self.model.geom("target_ball_contact").bodyid[0]
|
||||
self._bat_front_id = self.model.geom("bat").bodyid[0]
|
||||
self._bat_back_id = self.model.geom("bat_back").bodyid[0]
|
||||
self._table_id = self.model.geom("table_tennis_table").bodyid[0]
|
||||
self._table_contact_id = self.model.geom("table_tennis_table").bodyid[0]
|
||||
self._id_set = True
|
||||
|
||||
def step(self, action):
|
||||
@ -45,11 +53,33 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
||||
|
||||
unstable_simulation = False
|
||||
|
||||
try:
|
||||
self.do_simulation(action, self.frame_skip)
|
||||
except Exception as e:
|
||||
print("Simulation get unstable return with MujocoException: ", e)
|
||||
unstable_simulation = True
|
||||
done = False
|
||||
|
||||
for _ in range(self.frame_skip):
|
||||
try:
|
||||
self.do_simulation(action, self.frame_skip)
|
||||
except Exception as e:
|
||||
print("Simulation get unstable return with MujocoException: ", e)
|
||||
unstable_simulation = True
|
||||
|
||||
if not self.hit_ball:
|
||||
self.hit_ball = self._contact_checker(self._ball_contact_id, self._bat_front_id) or \
|
||||
self._contact_checker(self._ball_contact_id, self._bat_back_id)
|
||||
if not self.hit_ball:
|
||||
ball_land_on_floor_no_hit = self._contact_checker(self._ball_contact_id, self._floor_contact_id)
|
||||
if ball_land_on_floor_no_hit:
|
||||
self.ball_landing_pos = self.data.body("target_ball").xpos.copy()
|
||||
done = True
|
||||
if self.hit_ball and not self.ball_contact_after_hit:
|
||||
if not self.ball_contact_after_hit:
|
||||
if self._contact_checker(self._ball_contact_id, self._floor_contact_id): # first check contact with floor
|
||||
self.ball_contact_after_hit = True
|
||||
self.ball_landing_pos = self.sim.data.geom("target_ball_contact").xpos.copy()
|
||||
elif self._contact_checker(self._ball_contact_id, self._table_contact_id): # second check contact with table
|
||||
self.ball_contact_after_hit = True
|
||||
self.ball_landing_pos = self.sim.data.geom("target_ball_contact").xpos.copy()
|
||||
if self.ball_landing_pos[0] < 0.: # ball lands on the opponent side
|
||||
self.ball_return_success = True
|
||||
|
||||
self._steps += 1
|
||||
episode_end = True if self._steps >= MAX_EPISODE_STEPS_TABLE_TENNIS else False
|
||||
@ -67,8 +97,19 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
||||
|
||||
def reset_model(self):
|
||||
self._steps = 0
|
||||
new_context = self._sample_context()
|
||||
self.data.joint("tar_x").qpos = new_context[0]
|
||||
self.data.joint("tar_y").qpos = new_context[1]
|
||||
self.data.joint("tar_z").qvel = 2.
|
||||
|
||||
self.ball_landing_pos = None
|
||||
self.hit_ball = False
|
||||
return self._get_obs()
|
||||
|
||||
def _sample_context(self):
|
||||
return self.np_random.uniform(low=self.context_bounds[0],
|
||||
high=self.context_bounds[1])
|
||||
|
||||
def _get_obs(self):
|
||||
obs = np.concatenate([
|
||||
self.data.qpos.flat[:7],
|
||||
@ -80,6 +121,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
||||
if __name__ == "__main__":
|
||||
env = TableTennisEnv()
|
||||
env.reset()
|
||||
while True:
|
||||
env.render("human")
|
||||
env.step(env.action_space.sample())
|
||||
for _ in range(1000):
|
||||
for _ in range(200):
|
||||
env.render("human")
|
||||
env.step(env.action_space.sample())
|
||||
|
Loading…
Reference in New Issue
Block a user