diff --git a/fancy_gym/envs/mujoco/table_tennis/assets/xml/include_target_ball.xml b/fancy_gym/envs/mujoco/table_tennis/assets/xml/include_target_ball.xml index feb9125..bf77c0f 100644 --- a/fancy_gym/envs/mujoco/table_tennis/assets/xml/include_target_ball.xml +++ b/fancy_gym/envs/mujoco/table_tennis/assets/xml/include_target_ball.xml @@ -1,8 +1,8 @@ - - - + + + diff --git a/fancy_gym/envs/mujoco/table_tennis/assets/xml/table_tennis_env.xml b/fancy_gym/envs/mujoco/table_tennis/assets/xml/table_tennis_env.xml index afacf37..8c2aba3 100644 --- a/fancy_gym/envs/mujoco/table_tennis/assets/xml/table_tennis_env.xml +++ b/fancy_gym/envs/mujoco/table_tennis/assets/xml/table_tennis_env.xml @@ -12,8 +12,8 @@ - - + + diff --git a/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py b/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py index 86ce595..fdde6f5 100644 --- a/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py +++ b/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py @@ -18,25 +18,33 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): 7 DoF table tennis environment """ - def __init__(self, frame_skip: int = 4): + def __init__(self, ctxt_dim: int = 2, frame_skip: int = 4): utils.EzPickle.__init__(**locals()) self._steps = 0 self.hit_ball = False self.ball_land_on_table = False self._id_set = False + self.ball_landing_pos = None MujocoEnv.__init__(self, model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"), frame_skip=frame_skip, mujoco_bindings="mujoco") + if ctxt_dim == 2: + self.context_bounds = CONTEXT_BOUNDS_2DIMS + elif ctxt_dim == 4: + self.context_bounds = CONTEXT_BOUNDS_4DIMS + else: + raise NotImplementedError + self.action_space = spaces.Box(low=-1, high=1, shape=(7,), dtype=np.float32) def _set_ids(self): - self._floor_id = self.model.geom("floor").bodyid[0] - self._ball_id = self.model.geom("target_ball_contact").bodyid[0] + self._floor_contact_id = self.model.geom("floor").bodyid[0] + self._ball_contact_id = self.model.geom("target_ball_contact").bodyid[0] self._bat_front_id = self.model.geom("bat").bodyid[0] self._bat_back_id = self.model.geom("bat_back").bodyid[0] - self._table_id = self.model.geom("table_tennis_table").bodyid[0] + self._table_contact_id = self.model.geom("table_tennis_table").bodyid[0] self._id_set = True def step(self, action): @@ -45,11 +53,33 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): unstable_simulation = False - try: - self.do_simulation(action, self.frame_skip) - except Exception as e: - print("Simulation get unstable return with MujocoException: ", e) - unstable_simulation = True + done = False + + for _ in range(self.frame_skip): + try: + self.do_simulation(action, self.frame_skip) + except Exception as e: + print("Simulation get unstable return with MujocoException: ", e) + unstable_simulation = True + + if not self.hit_ball: + self.hit_ball = self._contact_checker(self._ball_contact_id, self._bat_front_id) or \ + self._contact_checker(self._ball_contact_id, self._bat_back_id) + if not self.hit_ball: + ball_land_on_floor_no_hit = self._contact_checker(self._ball_contact_id, self._floor_contact_id) + if ball_land_on_floor_no_hit: + self.ball_landing_pos = self.data.body("target_ball").xpos.copy() + done = True + if self.hit_ball and not self.ball_contact_after_hit: + if not self.ball_contact_after_hit: + if self._contact_checker(self._ball_contact_id, self._floor_contact_id): # first check contact with floor + self.ball_contact_after_hit = True + self.ball_landing_pos = self.sim.data.geom("target_ball_contact").xpos.copy() + elif self._contact_checker(self._ball_contact_id, self._table_contact_id): # second check contact with table + self.ball_contact_after_hit = True + self.ball_landing_pos = self.sim.data.geom("target_ball_contact").xpos.copy() + if self.ball_landing_pos[0] < 0.: # ball lands on the opponent side + self.ball_return_success = True self._steps += 1 episode_end = True if self._steps >= MAX_EPISODE_STEPS_TABLE_TENNIS else False @@ -67,8 +97,19 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): def reset_model(self): self._steps = 0 + new_context = self._sample_context() + self.data.joint("tar_x").qpos = new_context[0] + self.data.joint("tar_y").qpos = new_context[1] + self.data.joint("tar_z").qvel = 2. + + self.ball_landing_pos = None + self.hit_ball = False return self._get_obs() + def _sample_context(self): + return self.np_random.uniform(low=self.context_bounds[0], + high=self.context_bounds[1]) + def _get_obs(self): obs = np.concatenate([ self.data.qpos.flat[:7], @@ -80,6 +121,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): if __name__ == "__main__": env = TableTennisEnv() env.reset() - while True: - env.render("human") - env.step(env.action_space.sample()) + for _ in range(1000): + for _ in range(200): + env.render("human") + env.step(env.action_space.sample())